From 710940608cfe57da1e3750e45ce84152d0bb0f4d Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 24 Jan 2024 09:42:25 +0800 Subject: [PATCH] feat: sync whisper.cpp (#186) --- cpp/coreml/whisper-encoder.mm | 2 +- cpp/ggml-alloc.c | 52 +- cpp/ggml-alloc.h | 4 +- cpp/ggml-backend-impl.h | 72 +- cpp/ggml-backend.c | 899 +++++-- cpp/ggml-backend.h | 88 +- cpp/ggml-impl.h | 3 + cpp/ggml-metal-whisper.metal | 1594 +++++++++--- cpp/ggml-metal.h | 60 +- cpp/ggml-metal.m | 3991 +++++++++++++++--------------- cpp/ggml-quants.c | 2651 ++++++++++++++++---- cpp/ggml-quants.h | 40 +- cpp/ggml.c | 1000 ++++++-- cpp/ggml.h | 137 +- cpp/whisper.cpp | 204 +- example/ios/Podfile.lock | 4 +- scripts/bootstrap.sh | 1 - scripts/ggml-metal.m.patch | 48 +- scripts/whisper-encoder.mm.patch | 14 - scripts/whisper.cpp.patch | 16 +- src/version.json | 2 +- whisper.cpp | 2 +- 22 files changed, 7291 insertions(+), 3593 deletions(-) delete mode 100644 scripts/whisper-encoder.mm.patch diff --git a/cpp/coreml/whisper-encoder.mm b/cpp/coreml/whisper-encoder.mm index 49ba8a8..81a5a6a 100644 --- a/cpp/coreml/whisper-encoder.mm +++ b/cpp/coreml/whisper-encoder.mm @@ -24,7 +24,7 @@ // select which device to run the Core ML model on MLModelConfiguration *config = [[MLModelConfiguration alloc] init]; - //config.computeUnits = MLComputeUnitsCPUAndGPU; + // config.computeUnits = MLComputeUnitsCPUAndGPU; //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; config.computeUnits = MLComputeUnitsAll; diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index dd3b37d..ca19cbc 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -72,7 +72,7 @@ static void remove_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_te // check if a tensor is allocated by this buffer static bool wsp_ggml_tallocr_is_own(wsp_ggml_tallocr_t alloc, const struct wsp_ggml_tensor * tensor) { - return tensor->buffer == alloc->buffer; + return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer); } static bool wsp_ggml_is_view(struct wsp_ggml_tensor * t) { @@ -102,8 +102,6 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t } } - AT_PRINTF("block %d\n", best_fit_block); - if (best_fit_block == -1) { // the last block is our last resort struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; @@ -117,6 +115,7 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t return; } } + struct free_block * block = &alloc->free_blocks[best_fit_block]; void * addr = block->addr; block->addr = (char*)block->addr + size; @@ -129,6 +128,8 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t } } + AT_PRINTF("block %d, addr %p\n", best_fit_block, addr); + tensor->data = addr; tensor->buffer = alloc->buffer; if (!alloc->measure) { @@ -229,6 +230,7 @@ void wsp_ggml_tallocr_reset(wsp_ggml_tallocr_t alloc) { alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows } else { alloc->free_blocks[0].size = wsp_ggml_backend_buffer_get_size(alloc->buffer) - align_offset; + wsp_ggml_backend_buffer_reset(alloc->buffer); } } @@ -263,9 +265,9 @@ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment) { return alloc; } -wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend) { +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_buft(struct wsp_ggml_backend_buffer_type * buft) { // create a backend buffer to get 
the correct tensor allocation sizes - wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, 1); + wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 1); // TODO: move alloc initialization to a common wsp_ggml_tallocr_new_impl function wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer); @@ -275,13 +277,22 @@ wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_bac return alloc; } -wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) { - wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, size); +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend) { + return wsp_ggml_tallocr_new_measure_from_buft(wsp_ggml_backend_get_default_buffer_type(backend)); +} + +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buft(struct wsp_ggml_backend_buffer_type * buft, size_t size) { + // create a backend buffer to get the correct tensor allocation sizes + wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, size); wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer); alloc->buffer_owned = true; return alloc; } +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) { + return wsp_ggml_tallocr_new_from_buft(wsp_ggml_backend_get_default_buffer_type(backend), size); +} + wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) { wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr)); @@ -449,11 +460,10 @@ static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view, if (update_backend) { view->backend = view->view_src->backend; } - view->buffer = view->view_src->buffer; + // views are initialized in the alloc buffer rather than the view_src buffer + view->buffer = alloc->buffer; view->data = (char *)view->view_src->data + view->view_offs; - // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend - // due to the wsp_ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft); if (!alloc->measure) { @@ -736,6 +746,10 @@ void wsp_ggml_allocr_set_parse_seq(wsp_ggml_allocr_t alloc, const int * list, in } void wsp_ggml_allocr_free(wsp_ggml_allocr_t alloc) { + if (alloc == NULL) { + return; + } + wsp_ggml_gallocr_free(alloc->galloc); wsp_ggml_tallocr_free(alloc->talloc); free(alloc); @@ -775,11 +789,22 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct ws } if (nbytes == 0) { - fprintf(stderr, "%s: no tensors to allocate\n", __func__); + // all the tensors in the context are already allocated +#ifndef NDEBUG + fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__); +#endif return NULL; } wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, nbytes); + if (buffer == NULL) { + // failed to allocate buffer +#ifndef NDEBUG + fprintf(stderr, "%s: failed to allocate buffer\n", __func__); +#endif + return NULL; + } + wsp_ggml_tallocr_t tallocr = wsp_ggml_tallocr_new_from_buffer(buffer); for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) { @@ -789,6 +814,11 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct ws } else { 
wsp_ggml_backend_view_init(buffer, t); } + } else { + if (t->view_src != NULL) { + // view of a pre-allocated tensor + wsp_ggml_backend_view_init(buffer, t); + } } } diff --git a/cpp/ggml-alloc.h b/cpp/ggml-alloc.h index 88a63bf..49d58be 100644 --- a/cpp/ggml-alloc.h +++ b/cpp/ggml-alloc.h @@ -52,8 +52,10 @@ typedef struct wsp_ggml_tallocr * wsp_ggml_tallocr_t; WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment); WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment); -WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer); +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buft(struct wsp_ggml_backend_buffer_type * buft, size_t size); WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size); // allocates an owned buffer +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer); +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_buft(struct wsp_ggml_backend_buffer_type * buft); WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend); WSP_GGML_API struct wsp_ggml_backend_buffer * wsp_ggml_tallocr_get_buffer(wsp_ggml_tallocr_t talloc); diff --git a/cpp/ggml-backend-impl.h b/cpp/ggml-backend-impl.h index 65edf15..83523e4 100644 --- a/cpp/ggml-backend-impl.h +++ b/cpp/ggml-backend-impl.h @@ -16,10 +16,14 @@ extern "C" { typedef void * wsp_ggml_backend_buffer_type_context_t; struct wsp_ggml_backend_buffer_type_i { - wsp_ggml_backend_buffer_t (*alloc_buffer) (wsp_ggml_backend_buffer_type_t buft, size_t size); - size_t (*get_alignment) (wsp_ggml_backend_buffer_type_t buft); // tensor alignment - size_t (*get_alloc_size) (wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor); // data size needed to allocate the tensor, including padding - bool (*supports_backend)(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend); // check if the buffer type is usable by the backend + const char * (*WSP_GGML_CALL get_name) (wsp_ggml_backend_buffer_type_t buft); + wsp_ggml_backend_buffer_t (*WSP_GGML_CALL alloc_buffer) (wsp_ggml_backend_buffer_type_t buft, size_t size); + size_t (*WSP_GGML_CALL get_alignment) (wsp_ggml_backend_buffer_type_t buft); // tensor alignment + size_t (*WSP_GGML_CALL get_alloc_size) (wsp_ggml_backend_buffer_type_t buft, const struct wsp_ggml_tensor * tensor); // data size needed to allocate the tensor, including padding + bool (*WSP_GGML_CALL supports_backend)(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend); // check if the buffer type is usable by the backend + // check if tensor data is in host memory + // should be equivalent to supports_backend(buft, wsp_ggml_backend_cpu_init()) + bool (*WSP_GGML_CALL is_host) (wsp_ggml_backend_buffer_type_t buft); }; struct wsp_ggml_backend_buffer_type { @@ -31,15 +35,15 @@ extern "C" { typedef void * wsp_ggml_backend_buffer_context_t; struct wsp_ggml_backend_buffer_i { - void (*free_buffer)(wsp_ggml_backend_buffer_t buffer); - //void (*reset) (wsp_ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras - void * (*get_base) (wsp_ggml_backend_buffer_t buffer); - void (*init_tensor)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); - void (*set_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, 
size_t offset, size_t size); - void (*get_tensor) (wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); - // (optional) copy tensor between different buffer-type, allow for single-copy tranfers - void (*cpy_tensor_from)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); - void (*cpy_tensor_to) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); + const char * (*WSP_GGML_CALL get_name) (wsp_ggml_backend_buffer_t buffer); + void (*WSP_GGML_CALL free_buffer)(wsp_ggml_backend_buffer_t buffer); + void * (*WSP_GGML_CALL get_base) (wsp_ggml_backend_buffer_t buffer); + void (*WSP_GGML_CALL init_tensor)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + void (*WSP_GGML_CALL set_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*WSP_GGML_CALL get_tensor) (wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*WSP_GGML_CALL cpy_tensor) (wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); // dst is in the buffer, src may be in any buffer + void (*WSP_GGML_CALL clear) (wsp_ggml_backend_buffer_t buffer, uint8_t value); + void (*WSP_GGML_CALL reset) (wsp_ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras }; struct wsp_ggml_backend_buffer { @@ -47,14 +51,17 @@ extern "C" { wsp_ggml_backend_buffer_type_t buft; wsp_ggml_backend_buffer_context_t context; size_t size; + enum wsp_ggml_backend_buffer_usage usage; }; - wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( + WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_backend_buffer_i iface, wsp_ggml_backend_buffer_context_t context, size_t size); + // do not use directly, use wsp_ggml_backend_tensor_copy instead + bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); // // Backend @@ -63,33 +70,31 @@ extern "C" { typedef void * wsp_ggml_backend_context_t; struct wsp_ggml_backend_i { - const char * (*get_name)(wsp_ggml_backend_t backend); + const char * (*WSP_GGML_CALL get_name)(wsp_ggml_backend_t backend); - void (*free)(wsp_ggml_backend_t backend); + void (*WSP_GGML_CALL free)(wsp_ggml_backend_t backend); // buffer allocation - wsp_ggml_backend_buffer_type_t (*get_default_buffer_type)(wsp_ggml_backend_t backend); + wsp_ggml_backend_buffer_type_t (*WSP_GGML_CALL get_default_buffer_type)(wsp_ggml_backend_t backend); - // (optional) asynchroneous tensor data access - void (*set_tensor_async)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) asynchronous tensor data access + void (*WSP_GGML_CALL set_tensor_async)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*WSP_GGML_CALL get_tensor_async)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*WSP_GGML_CALL cpy_tensor_async)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); - // 
(optional) asynchroneous tensor copy - void (*cpy_tensor_from_async)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); - void (*cpy_tensor_to_async) (wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); - - void (*synchronize) (wsp_ggml_backend_t backend); + // (optional) complete all pending operations + void (*WSP_GGML_CALL synchronize)(wsp_ggml_backend_t backend); // compute graph with a plan - wsp_ggml_backend_graph_plan_t (*graph_plan_create) (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); - void (*graph_plan_free) (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); - void (*graph_plan_compute)(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + wsp_ggml_backend_graph_plan_t (*WSP_GGML_CALL graph_plan_create) (wsp_ggml_backend_t backend, const struct wsp_ggml_cgraph * cgraph); + void (*WSP_GGML_CALL graph_plan_free) (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + void (*WSP_GGML_CALL graph_plan_compute)(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); - // compute graph without a plan - void (*graph_compute)(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + // compute graph without a plan (async) + bool (*WSP_GGML_CALL graph_compute)(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); // check if the backend supports an operation - bool (*supports_op)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op); + bool (*WSP_GGML_CALL supports_op)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op); }; struct wsp_ggml_backend { @@ -98,14 +103,13 @@ extern "C" { wsp_ggml_backend_context_t context; }; - // // Backend registry // - typedef wsp_ggml_backend_t (*wsp_ggml_backend_init_fn)(const char * params, void * user_data); + typedef wsp_ggml_backend_t (*WSP_GGML_CALL wsp_ggml_backend_init_fn)(const char * params, void * user_data); - void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data); + WSP_GGML_CALL void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data); #ifdef __cplusplus } diff --git a/cpp/ggml-backend.c b/cpp/ggml-backend.c index 75aeb99..dda7dbf 100644 --- a/cpp/ggml-backend.c +++ b/cpp/ggml-backend.c @@ -15,7 +15,11 @@ // backend buffer type -wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { +const char * wsp_ggml_backend_buft_name(wsp_ggml_backend_buffer_type_t buft) { + return buft->iface.get_name(buft); +} + +WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { return buft->iface.alloc_buffer(buft, size); } @@ -23,7 +27,7 @@ size_t wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type_t buft) return buft->iface.get_alignment(buft); } -size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) { // get_alloc_size is optional, defaults to wsp_ggml_nbytes if (buft->iface.get_alloc_size) { return buft->iface.get_alloc_size(buft, tensor); @@ -35,9 +39,16 @@ bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, return 
buft->iface.supports_backend(buft, backend); } +bool wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_type_t buft) { + if (buft->iface.is_host) { + return buft->iface.is_host(buft); + } + return false; +} + // backend buffer -wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( +WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_backend_buffer_i iface, wsp_ggml_backend_buffer_context_t context, @@ -51,11 +62,16 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( /* .buft = */ buft, /* .context = */ context, /* .size = */ size, + /* .usage = */ WSP_GGML_BACKEND_BUFFER_USAGE_ANY }; return buffer; } +const char * wsp_ggml_backend_buffer_name(wsp_ggml_backend_buffer_t buffer) { + return buffer->iface.get_name(buffer); +} + void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) { if (buffer == NULL) { return; @@ -79,7 +95,7 @@ void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { return base; } -void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { // init_tensor is optional if (buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); @@ -87,17 +103,43 @@ void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struc } size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer) { - return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type(buffer)); + return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_get_type(buffer)); } size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { - return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type(buffer), tensor); + return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_get_type(buffer), tensor); +} + +void wsp_ggml_backend_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) { + buffer->iface.clear(buffer, value); } -wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_type(wsp_ggml_backend_buffer_t buffer) { +bool wsp_ggml_backend_buffer_is_host(wsp_ggml_backend_buffer_t buffer) { + return wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_get_type(buffer)); +} + +void wsp_ggml_backend_buffer_set_usage(wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage) { + buffer->usage = usage; +} + +wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_get_type(wsp_ggml_backend_buffer_t buffer) { return buffer->buft; } +void wsp_ggml_backend_buffer_reset(wsp_ggml_backend_buffer_t buffer) { + if (buffer->iface.reset) { + buffer->iface.reset(buffer); + } +} + +bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + wsp_ggml_backend_buffer_t dst_buf = dst->view_src ? 
dst->view_src->buffer : dst->buffer; + if (dst_buf->iface.cpy_tensor) { + return src->buffer->iface.cpy_tensor(dst_buf, src, dst); + } + return false; +} + // backend const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) { @@ -131,30 +173,42 @@ void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct wsp_gg WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); - backend->iface.set_tensor_async(backend, tensor, data, offset, size); + if (backend->iface.set_tensor_async == NULL) { + wsp_ggml_backend_tensor_set(tensor, data, offset, size); + } else { + backend->iface.set_tensor_async(backend, tensor, data, offset, size); + } } void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); - backend->iface.get_tensor_async(backend, tensor, data, offset, size); + if (backend->iface.get_tensor_async == NULL) { + wsp_ggml_backend_tensor_get(tensor, data, offset, size); + } else { + backend->iface.get_tensor_async(backend, tensor, data, offset, size); + } } -void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +WSP_GGML_CALL void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set"); + WSP_GGML_ASSERT(buf != NULL && "tensor buffer not set"); WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); - tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size); + tensor->buffer->iface.set_tensor(buf, tensor, data, offset, size); } -void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { +WSP_GGML_CALL void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + wsp_ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set"); WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); - tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size); + tensor->buffer->iface.get_tensor(buf, tensor, data, offset, size); } void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) { @@ -175,16 +229,10 @@ void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backe void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { backend->iface.graph_plan_compute(backend, plan); - - // TODO: optional sync - wsp_ggml_backend_synchronize(backend); } -void wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { - backend->iface.graph_compute(backend, cgraph); - - // TODO: optional sync - wsp_ggml_backend_synchronize(backend); +bool wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + return backend->iface.graph_compute(backend, cgraph); } bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { @@ -209,28 +257,20 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str } void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { - //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); - //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); - // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, wsp_ggml_backend_name(src->backend), wsp_ggml_backend_name(dst->backend), wsp_ggml_nbytes(src)); - if (src == dst) { return; } - // TODO: allow backends to support copy to/from same backend - - if (dst->buffer->iface.cpy_tensor_from != NULL) { - dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst); - } else if (src->buffer->iface.cpy_tensor_to != NULL) { - src->buffer->iface.cpy_tensor_to(src->buffer, src, dst); - } else { - // shouldn't be hit when copying from/to CPU - #ifndef NDEBUG - fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to " - "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name); - #endif + if (wsp_ggml_backend_buffer_is_host(src->buffer)) { + wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src)); + } else if (wsp_ggml_backend_buffer_is_host(dst->buffer)) { + wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src)); + } else if (!wsp_ggml_backend_buffer_copy_tensor(src, dst)) { +#ifndef NDEBUG + fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, wsp_ggml_backend_buffer_name(src->buffer), wsp_ggml_backend_buffer_name(dst->buffer)); +#endif size_t nbytes = wsp_ggml_nbytes(src); void * data = malloc(nbytes); wsp_ggml_backend_tensor_get(src, data, 0, nbytes); @@ -239,6 +279,31 @@ void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_ } } +void wsp_ggml_backend_tensor_copy_async(wsp_ggml_backend_t 
backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + if (src == dst) { + return; + } + + if (wsp_ggml_backend_buft_supports_backend(src->buffer->buft, backend) && wsp_ggml_backend_buft_supports_backend(dst->buffer->buft, backend)) { + if (backend->iface.cpy_tensor_async != NULL) { + if (backend->iface.cpy_tensor_async(backend, src, dst)) { + return; + } + } + } + + size_t nbytes = wsp_ggml_nbytes(src); + if (wsp_ggml_backend_buffer_is_host(src->buffer)) { + wsp_ggml_backend_tensor_set_async(backend, dst, src->data, 0, nbytes); + } + else { + wsp_ggml_backend_tensor_copy(src, dst); + } +} + + // backend registry #define WSP_GGML_MAX_BACKENDS_REG 16 @@ -253,9 +318,9 @@ struct wsp_ggml_backend_reg { static struct wsp_ggml_backend_reg wsp_ggml_backend_registry[WSP_GGML_MAX_BACKENDS_REG]; static size_t wsp_ggml_backend_registry_count = 0; -static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data); +WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data); -static void wsp_ggml_backend_registry_init(void) { +WSP_GGML_CALL static void wsp_ggml_backend_registry_init(void) { static bool initialized = false; if (initialized) { @@ -268,21 +333,21 @@ static void wsp_ggml_backend_registry_init(void) { // add forward decls here to avoid including the backend headers #ifdef WSP_GGML_USE_CUBLAS - extern void wsp_ggml_backend_cuda_reg_devices(void); + extern WSP_GGML_CALL void wsp_ggml_backend_cuda_reg_devices(void); wsp_ggml_backend_cuda_reg_devices(); #endif #ifdef WSP_GGML_USE_METAL - extern wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); - extern wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void); + extern WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); + extern WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void); wsp_ggml_backend_register("Metal", wsp_ggml_backend_reg_metal_init, wsp_ggml_backend_metal_buffer_type(), NULL); #endif } -void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) { +WSP_GGML_CALL void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) { WSP_GGML_ASSERT(wsp_ggml_backend_registry_count < WSP_GGML_MAX_BACKENDS_REG); - int id = wsp_ggml_backend_registry_count; + size_t id = wsp_ggml_backend_registry_count; wsp_ggml_backend_registry[id] = (struct wsp_ggml_backend_reg) { /* .name = */ {0}, @@ -315,6 +380,8 @@ size_t wsp_ggml_backend_reg_find_by_name(const char * name) { return i; } } + + // not found return SIZE_MAX; } @@ -325,15 +392,15 @@ wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend_from_str(const char * backe const char * params = strchr(backend_str, ':'); char backend_name[128]; if (params == NULL) { - strcpy(backend_name, backend_str); + snprintf(backend_name, sizeof(backend_name), "%s", backend_str); params = ""; } else { - strncpy(backend_name, backend_str, params - backend_str); - backend_name[params - backend_str] = '\0'; + snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str); params++; } size_t backend_i = wsp_ggml_backend_reg_find_by_name(backend_name); + if 
(backend_i == SIZE_MAX) { fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); return NULL; @@ -372,69 +439,80 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_reg_alloc_buffer(size_t i, size_t siz // backend CPU -static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { +WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_name(wsp_ggml_backend_buffer_t buffer) { + return "CPU"; + + WSP_GGML_UNUSED(buffer); +} + +WSP_GGML_CALL static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { return (void *)buffer->context; } -static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { +WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { free(buffer->context); - WSP_GGML_UNUSED(buffer); } -static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); - WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - +WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); WSP_GGML_UNUSED(buffer); } -static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { - WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); - WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - +WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); WSP_GGML_UNUSED(buffer); } -static void wsp_ggml_backend_cpu_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { - wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src)); +WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + if (wsp_ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, wsp_ggml_nbytes(src)); + return true; + } + return false; WSP_GGML_UNUSED(buffer); } -static void wsp_ggml_backend_cpu_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { - wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src)); - - WSP_GGML_UNUSED(buffer); +WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); } static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .get_name = */ wsp_ggml_backend_cpu_buffer_name, /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer, /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from, - /* .cpy_tensor_to = */ 
wsp_ggml_backend_cpu_buffer_cpy_tensor_to, + /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ wsp_ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, }; // for buffers from ptr, free is not called static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { + /* .get_name = */ wsp_ggml_backend_cpu_buffer_name, /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_from, - /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_buffer_cpy_tensor_to, + /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ wsp_ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, }; static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 -static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { +WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) { + return "CPU"; + + WSP_GGML_UNUSED(buft); +} + +WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC? @@ -443,31 +521,95 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(w return wsp_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); } -static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) { +WSP_GGML_CALL static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) { return TENSOR_ALIGNMENT; WSP_GGML_UNUSED(buft); } -static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) { +WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) { return wsp_ggml_backend_is_cpu(backend); WSP_GGML_UNUSED(buft); } -wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) { - static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_cpu = { +WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) { + return true; + + WSP_GGML_UNUSED(buft); +} + +WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) { + static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type = { /* .iface = */ { + /* .get_name = */ wsp_ggml_backend_cpu_buffer_type_get_name, /* .alloc_buffer = */ wsp_ggml_backend_cpu_buffer_type_alloc_buffer, /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment, /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, }; - return &wsp_ggml_backend_buffer_type_cpu; + return &wsp_ggml_backend_cpu_buffer_type; +} + +#ifdef WSP_GGML_USE_CPU_HBM + +// buffer type HBM + +#include + +WSP_GGML_CALL static const char * 
wsp_ggml_backend_cpu_hbm_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; + + WSP_GGML_UNUSED(buft); +} + +WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_hbm_buffer_get_name(wsp_ggml_backend_buffer_t buf) { + return "CPU_HBM"; + + WSP_GGML_UNUSED(buf); +} + +WSP_GGML_CALL static void wsp_ggml_backend_cpu_hbm_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); +} + +WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { + //void * ptr = hbw_malloc(size); + void * ptr; + int result = hbw_posix_memalign(&ptr, wsp_ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); + return NULL; + } + + wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.get_name = wsp_ggml_backend_cpu_hbm_buffer_get_name; + buffer->iface.free_buffer = wsp_ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_hbm_buffer_type(void) { + static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ wsp_ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment, + /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes + /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &wsp_ggml_backend_cpu_buffer_type_hbm; } +#endif struct wsp_ggml_backend_cpu_context { int n_threads; @@ -475,20 +617,20 @@ struct wsp_ggml_backend_cpu_context { size_t work_size; }; -static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) { +WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) { return "CPU"; WSP_GGML_UNUSED(backend); } -static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) { +WSP_GGML_CALL static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) { struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; free(cpu_ctx->work_data); free(cpu_ctx); free(backend); } -static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) { +WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) { return wsp_ggml_backend_cpu_buffer_type(); WSP_GGML_UNUSED(backend); @@ -499,13 +641,13 @@ struct wsp_ggml_backend_plan_cpu { struct wsp_ggml_cgraph cgraph; }; -static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { +WSP_GGML_CALL static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, const struct wsp_ggml_cgraph * cgraph) { struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; struct wsp_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct wsp_ggml_backend_plan_cpu)); cpu_plan->cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads); - cpu_plan->cgraph = *cgraph; + cpu_plan->cgraph = *cgraph; // FIXME: deep copy if 
(cpu_plan->cplan.work_size > 0) { cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); @@ -514,7 +656,7 @@ static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ return cpu_plan; } -static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { +WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan; free(cpu_plan->cplan.work_data); @@ -523,7 +665,7 @@ static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp WSP_GGML_UNUSED(backend); } -static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { +WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan; wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); @@ -531,7 +673,7 @@ static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, WSP_GGML_UNUSED(backend); } -static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { +WSP_GGML_CALL static bool wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads); @@ -545,13 +687,20 @@ static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struc cplan.work_data = cpu_ctx->work_data; wsp_ggml_graph_compute(cgraph, &cplan); + return true; } -static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { - return true; +WSP_GGML_CALL static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { + switch (op->op) { + case WSP_GGML_OP_CPY: + return op->type != WSP_GGML_TYPE_IQ2_XXS && op->type != WSP_GGML_TYPE_IQ2_XS; // missing type_traits.from_float + case WSP_GGML_OP_MUL_MAT: + return op->src[1]->type == WSP_GGML_TYPE_F32 || op->src[1]->type == wsp_ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + default: + return true; + } WSP_GGML_UNUSED(backend); - WSP_GGML_UNUSED(op); } static struct wsp_ggml_backend_i cpu_backend_i = { @@ -560,8 +709,7 @@ static struct wsp_ggml_backend_i cpu_backend_i = { /* .get_default_buffer_type = */ wsp_ggml_backend_cpu_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .cpy_tensor_from_async = */ NULL, - /* .cpy_tensor_to_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create, /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free, @@ -586,8 +734,8 @@ wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) { return cpu_backend; } -bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) { - return backend->iface.get_name == wsp_ggml_backend_cpu_name; +WSP_GGML_CALL bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) { + return backend && backend->iface.get_name == wsp_ggml_backend_cpu_name; } void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads) { @@ -597,11 +745,11 @@ void 
wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_th ctx->n_threads = n_threads; } -wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { +WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { return wsp_ggml_backend_buffer_init(wsp_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); } -static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) { +WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) { return wsp_ggml_backend_cpu_init(); WSP_GGML_UNUSED(params); @@ -611,7 +759,7 @@ static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, voi // scheduler -#define WSP_GGML_MAX_BACKENDS 4 +#define WSP_GGML_MAX_BACKENDS 16 #define WSP_GGML_MAX_SPLITS 256 #define WSP_GGML_MAX_SPLIT_INPUTS 16 @@ -621,21 +769,29 @@ struct wsp_ggml_backend_sched_split { int i_end; struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS]; int n_inputs; + // graph view of this split struct wsp_ggml_cgraph graph; }; struct wsp_ggml_backend_sched { + bool is_reset; // true if the scheduler has been reset since the last graph split + int n_backends; wsp_ggml_backend_t backends[WSP_GGML_MAX_BACKENDS]; + wsp_ggml_backend_buffer_type_t bufts[WSP_GGML_MAX_BACKENDS]; wsp_ggml_tallocr_t tallocs[WSP_GGML_MAX_BACKENDS]; wsp_ggml_gallocr_t galloc; + // hash keys of the nodes in the graph struct wsp_ggml_hash_set hash_set; - wsp_ggml_tallocr_t * node_talloc; // [hash_set.size] - struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // [hash_set.size][WSP_GGML_MAX_BACKENDS] + // hash values (arrays of [hash_set.size]) + wsp_ggml_tallocr_t * node_talloc; // tallocr assigned to each node (indirectly this is the backend) + struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // copies of each node for each destination backend + // copy of the graph with modified inputs struct wsp_ggml_cgraph * graph; + struct wsp_ggml_backend_sched_split splits[WSP_GGML_MAX_SPLITS]; int n_splits; @@ -648,6 +804,9 @@ struct wsp_ggml_backend_sched { __attribute__((aligned(WSP_GGML_MEM_ALIGN))) #endif char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + sizeof(struct wsp_ggml_cgraph)]; + + wsp_ggml_backend_sched_eval_callback callback_eval; + void * callback_eval_user_data; }; #define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node) @@ -676,14 +835,22 @@ static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t return INT_MAX; } -static wsp_ggml_backend_t get_buffer_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) { +static wsp_ggml_tallocr_t sched_allocr_from_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) { if (buffer == NULL) { return NULL; } + + // check if this is already allocate in a allocr buffer (from user manual allocations) + for (int i = 0; i < sched->n_backends; i++) { + if (wsp_ggml_tallocr_get_buffer(sched->tallocs[i]) == buffer) { + return sched->tallocs[i]; + } + } + // find highest prio backend that supports the buffer type for (int i = 0; i < sched->n_backends; i++) { if (wsp_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { - return sched->backends[i]; + return sched->tallocs[i]; } } WSP_GGML_ASSERT(false && "tensor buffer type not supported by any backend"); @@ -693,7 +860,6 @@ static wsp_ggml_backend_t 
get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp if (allocr == NULL) { return NULL; } - // find highest prio backend that supports the buffer type for (int i = 0; i < sched->n_backends; i++) { if (sched->tallocs[i] == allocr) { return sched->backends[i]; @@ -703,7 +869,7 @@ static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp } #if 0 -static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove +static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*16 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug only #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) #define GET_CAUSE(node) causes[hash_id(node)] #else @@ -712,45 +878,37 @@ static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*8 + WSP_GGML_MAX_SPLITS*WSP_GGML_ #endif // returns the backend that should be used for the node based on the current locations -static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) { - // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there - // ie. kv cache updates - // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend. +static wsp_ggml_tallocr_t sched_allocr_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) { + // assign pre-allocated nodes to their backend // dst - wsp_ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer); - if (cur_backend != NULL) { + wsp_ggml_tallocr_t cur_allocr = sched_allocr_from_buffer(sched, node->buffer); + if (cur_allocr != NULL) { SET_CAUSE(node, "1.dst"); - return cur_backend; + return cur_allocr; } - // view_src - if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) { - SET_CAUSE(node, "1.vsrc"); - return get_buffer_backend(sched, node->view_src->buffer); + if (node->view_src != NULL) { + cur_allocr = sched_allocr_from_buffer(sched, node->view_src->buffer); + if (cur_allocr != NULL) { + SET_CAUSE(node, "1.vsrc"); + return cur_allocr; + } } - - // src - int cur_prio = INT_MAX; - size_t cur_size = 0; - + // assign nodes that use weights to the backend of the weights for (int i = 0; i < WSP_GGML_MAX_SRC; i++) { const struct wsp_ggml_tensor * src = node->src[i]; if (src == NULL) { break; } - wsp_ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer); - if (src_backend != NULL) { - int src_prio = sched_backend_prio(sched, src_backend); - size_t src_size = wsp_ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - cur_backend = src_backend; - SET_CAUSE(node, "1.src%d", i); - } + if (src->buffer != NULL && src->buffer->usage == WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + wsp_ggml_tallocr_t src_allocr = sched_allocr_from_buffer(sched, src->buffer); + // operations with weights are always run on the same backend as the weights + SET_CAUSE(node, "1.wgt%d", i); + return src_allocr; } } - return cur_backend; + + return NULL; } static char * fmt_size(size_t size) { @@ -783,7 +941,7 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g } wsp_ggml_tallocr_t node_allocr = node_allocr(node); wsp_ggml_backend_t node_backend = node_allocr ? 
get_allocr_backend(sched, node_allocr) : NULL; // FIXME: - fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name, + fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name, fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node)); for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { struct wsp_ggml_tensor * src = node->src[j]; @@ -792,7 +950,7 @@ static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_g } wsp_ggml_tallocr_t src_allocr = node_allocr(src); wsp_ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL; - fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, + fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name, fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); } fprintf(stderr, "\n"); @@ -808,15 +966,17 @@ static struct wsp_ggml_tensor * wsp_ggml_dup_tensor_layout(struct wsp_ggml_conte return dup; } + +//#define DEBUG_PASS1 +//#define DEBUG_PASS2 +//#define DEBUG_PASS3 +//#define DEBUG_PASS4 + // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend -// TODO: merge passes static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) { - // reset state - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); - memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); - memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); + // reset splits sched->n_splits = 0; + sched->is_reset = false; struct wsp_ggml_init_params params = { /* .mem_size = */ sizeof(sched->context_buffer), @@ -824,26 +984,22 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg /* .no_alloc = */ true }; - if (sched->ctx != NULL) { - wsp_ggml_free(sched->ctx); - } + wsp_ggml_free(sched->ctx); sched->ctx = wsp_ggml_init(params); + if (sched->ctx == NULL) { + fprintf(stderr, "%s: failed to initialize context\n", __func__); + WSP_GGML_ASSERT(false); + } - // pass 1: assign backends to ops with allocated inputs + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct wsp_ggml_tensor * leaf = graph->leafs[i]; if (node_allocr(leaf) != NULL) { // do not overwrite user assignments continue; } - wsp_ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer); - if (leaf_backend == NULL && leaf->view_src != NULL) { - leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer); - } - if (leaf_backend != NULL) { - node_allocr(leaf) = wsp_ggml_backend_sched_get_tallocr(sched, leaf_backend); - } + node_allocr(leaf) = sched_allocr_from_cur(sched, leaf); } for (int i = 0; i < graph->n_nodes; i++) { @@ -852,50 +1008,120 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg // do not overwrite user assignments continue; } - wsp_ggml_backend_t node_backend = sched_backend_from_cur(sched, node); - if (node_backend != NULL) { - node_allocr(node) = wsp_ggml_backend_sched_get_tallocr(sched, node_backend); + node_allocr(node) = sched_allocr_from_cur(sched, node); + // src + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + if (node_allocr(src) == NULL) { + node_allocr(src) 
= sched_allocr_from_cur(sched, src); + } } } - //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS1 + fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#endif - // pass 2: assign backends to ops from current assignments - // TODO: - // - reuse sched_backend_from_cur - for (int i = 0; i < graph->n_nodes; i++) { - struct wsp_ggml_tensor * node = graph->nodes[i]; - wsp_ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr == NULL) { - int cur_prio = INT_MAX; - size_t cur_size = 0; - for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { - struct wsp_ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + // pass 2: expand current backend assignments + // assign the same backend to adjacent nodes + // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend) + // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops + + // pass 2.1 expand gpu up + { + wsp_ggml_tallocr_t cur_allocr = NULL; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr != NULL) { + if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_allocr = NULL; + } else { + cur_allocr = node_allocr; } - wsp_ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != NULL) { - int src_prio = sched_allocr_prio(sched, src_allocr); - size_t src_size = wsp_ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - node_allocr = src_allocr; - SET_CAUSE(node, "2.src%d", j); - } + } else { + node_allocr(node) = cur_allocr; + SET_CAUSE(node, "2.1"); + } + } + } + + // pass 2.2 expand gpu down + { + wsp_ggml_tallocr_t cur_allocr = NULL; + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr != NULL) { + if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_allocr = NULL; + } else { + cur_allocr = node_allocr; } + } else { + node_allocr(node) = cur_allocr; + SET_CAUSE(node, "2.2"); + } + } + } + + // pass 2.3 expand rest up + { + wsp_ggml_tallocr_t cur_allocr = NULL; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr != NULL) { + cur_allocr = node_allocr; + } else { + node_allocr(node) = cur_allocr; + SET_CAUSE(node, "2.3"); } + } + } + + // pass 2.4 expand rest down + { + wsp_ggml_tallocr_t cur_allocr = NULL; + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + wsp_ggml_tallocr_t node_allocr = node_allocr(node); if (node_allocr != NULL) { - node_allocr(node) = node_allocr; + cur_allocr = node_allocr; + } else { + node_allocr(node) = cur_allocr; + SET_CAUSE(node, "2.4"); } } } - //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS2 + fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#endif - // pass 3: assign backends to 
remaining src from dst (should only be leafs) + // pass 3: assign backends to remaining src from dst and view_src for (int i = 0; i < graph->n_nodes; i++) { struct wsp_ggml_tensor * node = graph->nodes[i]; - wsp_ggml_tallocr_t node_allocr = node_allocr(node); + wsp_ggml_tallocr_t cur_allocr = node_allocr(node); + if (node->view_src != NULL && cur_allocr == NULL) { + cur_allocr = node_allocr(node) = node_allocr(node->view_src); + SET_CAUSE(node, "3.vsrc"); + } for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { struct wsp_ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -903,81 +1129,107 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg } wsp_ggml_tallocr_t src_allocr = node_allocr(src); if (src_allocr == NULL) { - node_allocr(src) = node_allocr; + if (src->view_src != NULL) { + // views are always on the same backend as the source + node_allocr(src) = node_allocr(src->view_src); + SET_CAUSE(src, "3.vsrc"); + } else { + node_allocr(src) = cur_allocr; + SET_CAUSE(src, "3.cur"); + } } } } - //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS3 + fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#endif // pass 4: split graph, find tensors that need to be copied - // TODO: - // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost - // find first backend - int cur_split = 0; - for (int i = 0; i < graph->n_nodes; i++) { - struct wsp_ggml_tensor * node = graph->nodes[i]; - if (node->view_src == NULL) { - sched->splits[0].tallocr = node_allocr(node); - break; - } - } - sched->splits[0].i_start = 0; - sched->splits[0].n_inputs = 0; - memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK - wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; - size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); - for (int i = 0; i < graph->n_nodes; i++) { - struct wsp_ggml_tensor * node = graph->nodes[i]; - - if (wsp_ggml_is_view_op(node->op)) { - continue; + { + int cur_split = 0; + // find the backend of the first split, skipping view ops + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (!wsp_ggml_is_view_op(node->op)) { + sched->splits[0].tallocr = node_allocr(node); + break; + } } + sched->splits[0].i_start = 0; + sched->splits[0].n_inputs = 0; + memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK + wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; + size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + + if (wsp_ggml_is_view_op(node->op)) { + continue; + } - wsp_ggml_tallocr_t node_allocr = node_allocr(node); + wsp_ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr != cur_allocr) { - sched->splits[cur_split].i_end = i; - cur_split++; - WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS); - sched->splits[cur_split].tallocr = node_allocr; - sched->splits[cur_split].i_start = i; - sched->splits[cur_split].n_inputs = 0; - memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK - cur_allocr = node_allocr; - cur_backend_id = sched_allocr_prio(sched, cur_allocr); - } + WSP_GGML_ASSERT(node_allocr != NULL); // all nodes should be assigned by now - // find inputs that are not on the same backend - for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { - 
struct wsp_ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + if (node_allocr != cur_allocr) { + sched->splits[cur_split].i_end = i; + cur_split++; + WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS); + sched->splits[cur_split].tallocr = node_allocr; + sched->splits[cur_split].i_start = i; + sched->splits[cur_split].n_inputs = 0; + cur_allocr = node_allocr; + cur_backend_id = sched_allocr_prio(sched, cur_allocr); } - wsp_ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != node_allocr) { - int n_inputs = sched->splits[cur_split].n_inputs++; - WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS); - sched->splits[cur_split].inputs[n_inputs] = (struct wsp_ggml_tensor *)src; - - // create copies - size_t id = hash_id(src); - if (sched->node_copies[id][cur_backend_id] == NULL) { - struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src); - sched->node_copies[id][cur_backend_id] = tensor_copy; - node_allocr(tensor_copy) = cur_allocr; - wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr); - wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name); + + // find inputs that are not on the same backend + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + WSP_GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now + if (src_allocr != node_allocr) { + // check if the input is already in the split + bool found = false; + for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) { + if (sched->splits[cur_split].inputs[k] == src) { + found = true; + break; + } + } + + if (!found) { + int n_inputs = sched->splits[cur_split].n_inputs++; + //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr))); + WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS); + sched->splits[cur_split].inputs[n_inputs] = src; + } + + // create a copy of the input in the split's backend + size_t id = hash_id(src); + if (sched->node_copies[id][cur_backend_id] == NULL) { + wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr); + struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src); + wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name); + + sched->node_copies[id][cur_backend_id] = tensor_copy; + node_allocr(tensor_copy) = cur_allocr; + SET_CAUSE(tensor_copy, "4.cpy"); + } + node->src[j] = sched->node_copies[id][cur_backend_id]; } - node->src[j] = sched->node_copies[id][cur_backend_id]; } } + sched->splits[cur_split].i_end = graph->n_nodes; + sched->n_splits = cur_split + 1; } - sched->splits[cur_split].i_end = graph->n_nodes; - sched->n_splits = cur_split + 1; - - //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout); +#ifdef DEBUG_PASS4 + fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#endif -#if 1 +#ifndef NDEBUG // sanity check: all sources should have the same backend as the node for (int i = 0; i < graph->n_nodes; i++) { struct wsp_ggml_tensor * node = graph->nodes[i]; @@ -985,6 +1237,11 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg if (node_allocr == NULL) { fprintf(stderr, "!!!!!!! %s has no backend\n", node->name); } + if (node->view_src != NULL && node_allocr != node_allocr(node->view_src)) { + fprintf(stderr, "!!!!!!! 
%s has backend %s, view_src %s has backend %s\n", + node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL", + node->view_src->name, node_allocr(node->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(node->view_src))) : "NULL"); + } for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { struct wsp_ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -996,8 +1253,14 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL", j, src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL"); } + if (src->view_src != NULL && src_allocr != node_allocr(src->view_src)) { + fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n", + src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL", + src->view_src->name, node_allocr(src->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(src->view_src))) : "NULL"); + } } } + fflush(stderr); #endif // create copies of the graph for each split @@ -1011,6 +1274,8 @@ static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cg for (int j = 0; j < split->n_inputs; j++) { struct wsp_ggml_tensor * input = split->inputs[j]; struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)]; + // add a dependency to the input source so that it is not freed before the copy is done + WSP_GGML_ASSERT(input_cpy->src[0] == NULL || input_cpy->src[0] == input); input_cpy->src[0] = input; graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; } @@ -1045,24 +1310,16 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) { uint64_t copy_start_us = wsp_ggml_time_us(); for (int j = 0; j < split->n_inputs; j++) { struct wsp_ggml_tensor * input = split->inputs[j]; - struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)]; - if (input->buffer == NULL) { - if (input->view_src == NULL) { - fprintf(stderr, "input %s has no buffer and no view_src\n", input->name); - exit(1); - } - // FIXME: may need to use the sched buffer instead - wsp_ggml_backend_view_init(input->view_src->buffer, input); - } - if (input_cpy->buffer == NULL) { - fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name); - exit(1); - } - //WSP_GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend); - //WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend); - wsp_ggml_backend_tensor_copy(input, input_cpy); + struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][split_backend_id]; + + WSP_GGML_ASSERT(input->buffer != NULL); + WSP_GGML_ASSERT(input_cpy->buffer != NULL); + + // TODO: avoid this copy if it was already copied in a previous split, and the input didn't change + // this is important to avoid copying constants such as KQ_mask and inp_pos multiple times + wsp_ggml_backend_tensor_copy_async(split_backend, input, input_cpy); } - // wsp_ggml_backend_synchronize(split_backend); + //wsp_ggml_backend_synchronize(split_backend); // necessary to measure copy time int64_t copy_end_us = wsp_ggml_time_us(); copy_us[split_backend_id] += copy_end_us - copy_start_us; @@ -1072,9 +1329,38 @@ static void sched_compute_splits(wsp_ggml_backend_sched_t sched) { wsp_ggml_graph_dump_dot(split->graph, NULL, split_filename); #endif + uint64_t compute_start_us = 
wsp_ggml_time_us(); - wsp_ggml_backend_graph_compute(split_backend, &split->graph); - // wsp_ggml_backend_synchronize(split_backend); + if (!sched->callback_eval) { + wsp_ggml_backend_graph_compute(split_backend, &split->graph); + //wsp_ggml_backend_synchronize(split_backend); // necessary to measure compute time + } else { + // similar to wsp_ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct wsp_ggml_tensor * t = split->graph.nodes[j0]; + + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); + + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); + } + + struct wsp_ggml_cgraph gv = wsp_ggml_graph_view(&split->graph, j0, j1 + 1); + + wsp_ggml_backend_graph_compute(split_backend, &gv); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } + } uint64_t compute_end_us = wsp_ggml_time_us(); compute_us[split_backend_id] += compute_end_us - compute_start_us; } @@ -1094,26 +1380,41 @@ static void sched_reset(wsp_ggml_backend_sched_t sched) { for (int i = 0; i < sched->n_backends; i++) { wsp_ggml_tallocr_reset(sched->tallocs[i]); } + // reset state for the next run + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); + memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); + memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); + + sched->is_reset = true; } -wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends) { +wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, wsp_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) { + WSP_GGML_ASSERT(n_backends > 0); WSP_GGML_ASSERT(n_backends <= WSP_GGML_MAX_BACKENDS); - struct wsp_ggml_backend_sched * sched = malloc(sizeof(struct wsp_ggml_backend_sched)); - memset(sched, 0, sizeof(struct wsp_ggml_backend_sched)); + struct wsp_ggml_backend_sched * sched = calloc(sizeof(struct wsp_ggml_backend_sched), 1); + + // initialize hash table + sched->hash_set = wsp_ggml_hash_set_new(graph_size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS); + sched->node_talloc = calloc(sizeof(sched->node_talloc[0]) * sched->hash_set.size, 1); + sched->node_copies = calloc(sizeof(sched->node_copies[0]) * sched->hash_set.size, 1); sched->n_backends = n_backends; for (int i = 0; i < n_backends; i++) { sched->backends[i] = backends[i]; + sched->bufts[i] = bufts ? 
bufts[i] : wsp_ggml_backend_get_default_buffer_type(backends[i]); } sched->galloc = wsp_ggml_gallocr_new(); // init measure allocs for each backend for (int i = 0; i < n_backends; i++) { - sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_backend(backends[i]); + sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_buft(sched->bufts[i]); } + sched_reset(sched); + return sched; } @@ -1125,6 +1426,7 @@ void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) { wsp_ggml_tallocr_free(sched->tallocs[i]); } wsp_ggml_gallocr_free(sched->galloc); + wsp_ggml_free(sched->ctx); free(sched->hash_set.keys); free(sched->node_talloc); free(sched->node_copies); @@ -1132,12 +1434,7 @@ void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) { } void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph) { - // initialize hash tables - size_t hash_size = measure_graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS; - sched->hash_set.size = hash_size; - sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size); - sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size); - sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size); + WSP_GGML_ASSERT(wsp_ggml_tallocr_is_measure(sched->tallocs[0])); // can only be initialized once sched_split_graph(sched, measure_graph); sched_alloc_splits(sched); @@ -1146,28 +1443,47 @@ void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct for (int i = 0; i < sched->n_backends; i++) { size_t size = wsp_ggml_tallocr_max_size(sched->tallocs[i]); wsp_ggml_tallocr_free(sched->tallocs[i]); - sched->tallocs[i] = wsp_ggml_tallocr_new_from_backend(sched->backends[i], size); + sched->tallocs[i] = wsp_ggml_tallocr_new_from_buft(sched->bufts[i], size); } sched_reset(sched); } void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) { - WSP_GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS); + WSP_GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS); + + if (!sched->is_reset) { + sched_reset(sched); + } sched_split_graph(sched, graph); sched_alloc_splits(sched); sched_compute_splits(sched); +} + +void wsp_ggml_backend_sched_reset(wsp_ggml_backend_sched_t sched) { sched_reset(sched); } + +void wsp_ggml_backend_sched_set_eval_callback(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_sched_eval_callback callback, void * user_data) { + sched->callback_eval = callback; + sched->callback_eval_user_data = user_data; +} + +int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched) { + return sched->n_splits; +} + wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) { int backend_index = sched_backend_prio(sched, backend); + WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); return sched->tallocs[backend_index]; } wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) { int backend_index = sched_backend_prio(sched, backend); + WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); return wsp_ggml_tallocr_get_buffer(sched->tallocs[backend_index]); } @@ -1177,10 +1493,19 @@ void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, str node_allocr(node) = 
sched->tallocs[backend_index]; } +wsp_ggml_backend_t wsp_ggml_backend_sched_get_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) { + wsp_ggml_tallocr_t allocr = node_allocr(node); + if (allocr == NULL) { + return NULL; + } + return get_allocr_backend(sched, allocr); +} + // utils + void wsp_ggml_backend_view_init(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { WSP_GGML_ASSERT(tensor->buffer == NULL); - WSP_GGML_ASSERT(tensor->data == NULL); + //WSP_GGML_ASSERT(tensor->data == NULL); // views of pre-allocated tensors may have the data set in wsp_ggml_new_tensor, but still need to be initialized by the backend WSP_GGML_ASSERT(tensor->view_src != NULL); WSP_GGML_ASSERT(tensor->view_src->buffer != NULL); WSP_GGML_ASSERT(tensor->view_src->data != NULL); @@ -1246,6 +1571,7 @@ static void graph_init_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml struct wsp_ggml_tensor * dst = node_copies[id]; if (dst->view_src != NULL) { + graph_init_tensor(hash_set, node_copies, node_init, src->view_src); wsp_ggml_backend_view_init(dst->view_src->buffer, dst); } else { @@ -1279,6 +1605,21 @@ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_ struct wsp_ggml_context * ctx_allocated = wsp_ggml_init(params); struct wsp_ggml_context * ctx_unallocated = wsp_ggml_init(params); + if (ctx_allocated == NULL || ctx_unallocated == NULL) { + fprintf(stderr, "failed to allocate context for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + wsp_ggml_free(ctx_allocated); + wsp_ggml_free(ctx_unallocated); + return (struct wsp_ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + // dup nodes for (int i = 0; i < graph->n_nodes; i++) { struct wsp_ggml_tensor * node = graph->nodes[i]; @@ -1287,6 +1628,20 @@ struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_ // allocate nodes wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); + if (buffer == NULL) { + fprintf(stderr, "failed to allocate buffer for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + wsp_ggml_free(ctx_allocated); + wsp_ggml_free(ctx_unallocated); + return (struct wsp_ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } //printf("copy buffer size: %zu MB\n", wsp_ggml_backend_buffer_get_size(buffer) / 1024 / 1024); @@ -1323,8 +1678,12 @@ void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy) { wsp_ggml_free(copy.ctx_unallocated); } -void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) { +bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) { struct wsp_ggml_backend_graph_copy copy = wsp_ggml_backend_graph_copy(backend2, graph); + if (copy.buffer == NULL) { + return false; + } + struct wsp_ggml_cgraph * g1 = graph; struct wsp_ggml_cgraph * g2 = copy.graph; @@ -1354,4 +1713,6 @@ void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggm } wsp_ggml_backend_graph_copy_free(copy); + + return true; } diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h 
index 3cdc7e8..8f23421 100644 --- a/cpp/ggml-backend.h +++ b/cpp/ggml-backend.h @@ -17,19 +17,31 @@ extern "C" { // // buffer type - WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size); - WSP_GGML_API size_t wsp_ggml_backend_buft_get_alignment (wsp_ggml_backend_buffer_type_t buft); - WSP_GGML_API size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor); - WSP_GGML_API bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend); + WSP_GGML_API const char * wsp_ggml_backend_buft_name (wsp_ggml_backend_buffer_type_t buft); + WSP_GGML_API WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer (wsp_ggml_backend_buffer_type_t buft, size_t size); + WSP_GGML_API size_t wsp_ggml_backend_buft_get_alignment (wsp_ggml_backend_buffer_type_t buft); + WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_backend_buft_get_alloc_size (wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend); + WSP_GGML_API bool wsp_ggml_backend_buft_is_host (wsp_ggml_backend_buffer_type_t buft); // buffer - WSP_GGML_API void wsp_ggml_backend_buffer_free (wsp_ggml_backend_buffer_t buffer); - WSP_GGML_API void * wsp_ggml_backend_buffer_get_base (wsp_ggml_backend_buffer_t buffer); - WSP_GGML_API size_t wsp_ggml_backend_buffer_get_size (wsp_ggml_backend_buffer_t buffer); - WSP_GGML_API void wsp_ggml_backend_buffer_init_tensor (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); - WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer); - WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); - WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_type(wsp_ggml_backend_buffer_t buffer); + enum wsp_ggml_backend_buffer_usage { + WSP_GGML_BACKEND_BUFFER_USAGE_ANY = 0, + WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + }; + + WSP_GGML_API const char * wsp_ggml_backend_buffer_name (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API void wsp_ggml_backend_buffer_free (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API void * wsp_ggml_backend_buffer_get_base (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_size (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API WSP_GGML_CALL void wsp_ggml_backend_buffer_init_tensor (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + WSP_GGML_API void wsp_ggml_backend_buffer_clear (wsp_ggml_backend_buffer_t buffer, uint8_t value); + WSP_GGML_API bool wsp_ggml_backend_buffer_is_host (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API void wsp_ggml_backend_buffer_set_usage (wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage); + WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_get_type (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API void wsp_ggml_backend_buffer_reset (wsp_ggml_backend_buffer_t buffer); // // Backend @@ -46,8 +58,8 @@ extern "C" { WSP_GGML_API void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct 
wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); WSP_GGML_API void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); - WSP_GGML_API void wsp_ggml_backend_tensor_set( struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); - WSP_GGML_API void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + WSP_GGML_API WSP_GGML_CALL void wsp_ggml_backend_tensor_set( struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + WSP_GGML_API WSP_GGML_CALL void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); WSP_GGML_API void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend); @@ -55,7 +67,7 @@ extern "C" { WSP_GGML_API void wsp_ggml_backend_graph_plan_free (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); WSP_GGML_API void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); - WSP_GGML_API void wsp_ggml_backend_graph_compute (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API bool wsp_ggml_backend_graph_compute (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); WSP_GGML_API bool wsp_ggml_backend_supports_op (wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op); // tensor copy between different backends @@ -68,13 +80,17 @@ extern "C" { WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void); - WSP_GGML_API bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend); - WSP_GGML_API void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads); + WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_backend_is_cpu (wsp_ggml_backend_t backend); + WSP_GGML_API void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads); // Create a backend buffer from an existing pointer - WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + WSP_GGML_API WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); - WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void); + WSP_GGML_API WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void); + +#ifdef WSP_GGML_USE_CPU_HBM + WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_hbm_buffer_type(void); +#endif // // Backend registry @@ -132,24 +148,36 @@ extern "C" { struct wsp_ggml_backend_sched; typedef struct wsp_ggml_backend_sched * wsp_ggml_backend_sched_t; - // Initialize a backend scheduler - WSP_GGML_API wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends); - - WSP_GGML_API void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched); + // when ask == true, the scheduler wants to know if the user wants to observe this node + // this allows the scheduler to batch nodes together in order to evaluate them in a single call + // + // when ask == false, the scheduler is passing the node tensor to the user for observation + // if the user returns false, the scheduler will cancel the graph compute + // + typedef bool (*wsp_ggml_backend_sched_eval_callback)(struct wsp_ggml_tensor * t, bool ask, void * user_data); + // Initialize a backend scheduler + WSP_GGML_API wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * 
backends, wsp_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size); + WSP_GGML_API void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph - WSP_GGML_API void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph); + WSP_GGML_API void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph); + // Get the number of splits of the last graph + WSP_GGML_API int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched); WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend); WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer (wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend); - WSP_GGML_API void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend); + WSP_GGML_API void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend); + WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_sched_get_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node); + + // Allocate and compute graph on the backend scheduler + WSP_GGML_API void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph); - // Allocate a graph on the backend scheduler - WSP_GGML_API void wsp_ggml_backend_sched_graph_compute( - wsp_ggml_backend_sched_t sched, - struct wsp_ggml_cgraph * graph); + // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs + WSP_GGML_API void wsp_ggml_backend_sched_reset(wsp_ggml_backend_sched_t sched); + // Set a callback to be called for each resulting node during graph compute + WSP_GGML_API void wsp_ggml_backend_sched_set_eval_callback(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_sched_eval_callback callback, void * user_data); // // Utils @@ -166,10 +194,10 @@ extern "C" { WSP_GGML_API struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * graph); WSP_GGML_API void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy); - typedef bool (*wsp_ggml_backend_eval_callback)(int node_index, struct wsp_ggml_tensor * t1, struct wsp_ggml_tensor * t2, void * user_data); + typedef bool (*WSP_GGML_CALL wsp_ggml_backend_eval_callback)(int node_index, struct wsp_ggml_tensor * t1, struct wsp_ggml_tensor * t2, void * user_data); // Compare the output of two backends - WSP_GGML_API void wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data); + WSP_GGML_API bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data); // Tensor initialization WSP_GGML_API void wsp_ggml_backend_tensor_alloc(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, void * addr); diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index ab73e85..cf097f1 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -5,6 +5,7 @@ // GGML internal header #include +#include // load `stdlib.h` before other headers to work around MinGW bug: 
https://sourceforge.net/p/mingw-w64/bugs/192/ #include #include #include // memcpy @@ -227,6 +228,8 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) { #define WSP_GGML_HASHTABLE_FULL ((size_t)-1) #define WSP_GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) +struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size); + bool wsp_ggml_hash_contains (const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key); // returns WSP_GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted diff --git a/cpp/ggml-metal-whisper.metal b/cpp/ggml-metal-whisper.metal index fe0ada4..029578d 100644 --- a/cpp/ggml-metal-whisper.metal +++ b/cpp/ggml-metal-whisper.metal @@ -59,26 +59,26 @@ kernel void kernel_add( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, constant int64_t & offs, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -109,26 +109,26 @@ kernel void kernel_mul( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -158,26 +158,26 @@ kernel void kernel_div( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + 
constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -205,7 +205,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -214,7 +214,7 @@ kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig % nb]; } @@ -223,7 +223,7 @@ kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -307,26 +307,26 @@ kernel void kernel_sum_rows( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tpig[[thread_position_in_grid]]) { int64_t i3 = tpig.z; int64_t i2 = tpig.y; @@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre #define N_SIMDGROUP 2 // number of SIMD groups in a thread group //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not -// giard against the number of rows not being divisible by +// guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. 
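// As an illustrative sketch only (not used by the kernels below, and assuming the
// block_q4_0 layout from ggml-quants: a half scale `d` followed by QK4_0/2 = 16 bytes
// of `qs` packing 32 signed 4-bit values), this is roughly how one element of such a
// 32-wide block dequantizes; it is the per-block pattern the mul_vec_q_n kernels
// accumulate over, which is why they assume a block size of 32:
static inline float dequantize_one_q4_0_sketch(device const block_q4_0 * b, int i) {
    // lower nibbles hold elements 0..15, upper nibbles hold elements 16..31
    const uint8_t q = b->qs[i % 16];
    const int     v = (i < 16 ? (q & 0x0F) : (q >> 4)) - 8;
    return (float) b->d * v;
}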
template void mul_vec_q_n_f32_impl( @@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & 
nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne10, + constant int64_t & ne11, constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32( const int64_t i3 = n / (ne2*ne1*ne0); const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - 
i1*ne0); + const int64_t k = i3*ne3 + i2; float m_k; @@ -1702,8 +1738,9 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { + for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { + if (ic < n_dims) { + const int64_t ib = 0; // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; @@ -1722,6 +1759,14 @@ kernel void kernel_rope( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -2401,21 +2446,18 @@ typedef struct { } block_q6_K; // 210 bytes / block -static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { - uchar4 r; - if (j < 4) { - r[0] = q[j+0] & 63; - r[2] = q[j+1] & 63; - r[1] = q[j+4] & 63; - r[3] = q[j+5] & 63; - } else { - r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); - } - return r; -} +typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +// 66 bytes / block for QK_K = 256, so 2.0625 bpw + +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +// 74 bytes / block for QK_K = 256, so 2.3125 bpw //====================================== dot products ========================= @@ -2575,14 +2617,21 @@ kernel void kernel_mul_mv_q2_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2832,14 +2881,21 @@ kernel void kernel_mul_mv_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & 
r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2975,8 +3031,8 @@ void kernel_mul_mv_q4_K_f32_impl( constant uint & r2, constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int ix = tiisg/4; // 0...7 const int it = tiisg%4; // 0...3 @@ -2985,7 +3041,7 @@ void kernel_mul_mv_q4_K_f32_impl( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; const int ib_row = first_row * nb; const uint i12 = im%ne12; @@ -3051,7 +3107,7 @@ void kernel_mul_mv_q4_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -3063,14 +3119,21 @@ kernel void kernel_mul_mv_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3262,14 +3325,21 @@ kernel void kernel_mul_mv_q5_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3389,14 +3459,21 @@ kernel void kernel_mul_mv_q6_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + 
constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3404,51 +3481,540 @@ kernel void kernel_mul_mv_q6_K_f32( kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} +// ======================= "True" 2-bit + +constexpr constant static uint64_t iq2xxs_grid[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 
0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} +constexpr constant static uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 
0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 
0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 
0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 
0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; +constexpr constant static uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} +constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + +#if QK_K == 256 + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d 
= db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } +#else + // TODO +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + +#if QK_K == 256 + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * 
ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } +#else + // TODO +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +//============================= templates and their specializations ============================= + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? 
(xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; } } @@ -3514,7 +4080,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg device const int8_t * qs = ((device const int8_t *)xb->qs); const half d = xb->d; - for (int i=0;i<16;i++) { + for (int i = 0; i < 16; i++) { reg[i/4][i%4] = (qs[i + 16*il] * d); } } @@ -3556,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; il = (il/2) & 3; const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); @@ -3624,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; + const float d = il < 2 ? xb->d : xb->d / 16.f; const float min = xb->dmin; const float dl = d * sc[0]; const float ml = min * sc[1]; @@ -3657,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; + float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; #else ql = ql + 16 * (il&1); - half sc = scales[il]; + float sc = scales[il]; #endif const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); @@ -3675,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? 
-1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -3755,48 +4367,212 @@ kernel void kernel_get_rows_f16( const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const char * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +void kernel_mul_mm_impl(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & 
nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - const int64_t i02 = i11; + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } } } -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const 
uchar * src1, + thread short * src1ids, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); @@ -3805,6 +4581,8 @@ void kernel_mul_mm_impl(device const uchar * src0, const uint r1 = tgpig.x; const uint im = tgpig.z; + if (r1 * BLOCK_SIZE_N >= ne1) return; + // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; @@ -3831,7 +4609,7 @@ void kernel_mul_mm_impl(device const uchar * src0, device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { @@ -3840,7 +4618,6 @@ void kernel_mul_mm_impl(device const uchar * src0, dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(16) for (int i = 0; i < 16; i++) { *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ @@ -3859,14 +4636,11 @@ void kernel_mul_mm_impl(device const uchar * src0, threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - #pragma unroll(4) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); } @@ -3874,21 +4648,13 @@ void kernel_mul_mm_impl(device const uchar * src0, lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - #pragma unroll(8) for (int i = 0; i < 8; i++){ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); } } } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix + { threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; @@ -3898,11 +4664,11 @@ void kernel_mul_mm_impl(device const uchar * src0, threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0) + 
im*ne1*ne0; if (sgitg == 0) { for (int i = 0; i < n_rows; i++) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); } } } @@ -3915,12 +4681,12 @@ kernel void kernel_mul_mm(device const uchar * src0, device float * dst, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3955,20 +4721,20 @@ template( - src0[id], - src1 + bid*nb11, - (device float *) (dst + bid*nb1), + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { + src1ids[_ne1++] = i1; + } + } + + kernel_mul_mm_id_impl( + src0s[id], + src1, + src1ids, + dst, ne00, ne02, nb01, @@ -4005,7 +4781,7 @@ kernel void kernel_mul_mm_id( nb11, nb12, ne0, - ne1, + _ne1, r2, r3, shared_memory, @@ -4050,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -4061,12 +4839,12 @@ typedef void (mat_mm_t)( device float * dst, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4086,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -4094,20 +4874,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -4143,8 +4925,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu kernel void kernel_mul_mv_id_f32_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant 
int64_t & ne01, constant int64_t & ne02, @@ -4160,7 +4942,7 @@ kernel void kernel_mul_mv_id_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4187,7 +4969,7 @@ kernel void kernel_mul_mv_id_f32_f32( kernel_mul_mv_f32_f32_impl( src0[id], src1 + bid*nb11, - (device float *) (dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4212,8 +4994,8 @@ kernel void kernel_mul_mv_id_f32_f32( kernel void kernel_mul_mv_id_f16_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4229,7 +5011,7 @@ kernel void kernel_mul_mv_id_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4256,7 +5038,7 @@ kernel void kernel_mul_mv_id_f16_f32( kernel_mul_mv_f16_f32_impl( src0[id], src1 + bid*nb11, - (device float *) (dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4281,8 +5063,8 @@ kernel void kernel_mul_mv_id_f16_f32( kernel void kernel_mul_mv_id_q8_0_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4298,7 +5080,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4325,7 +5107,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( kernel_mul_mv_q8_0_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4344,8 +5126,8 @@ kernel void kernel_mul_mv_id_q8_0_f32( kernel void kernel_mul_mv_id_q4_0_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4361,7 +5143,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4388,7 +5170,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4407,8 +5189,8 @@ kernel void kernel_mul_mv_id_q4_0_f32( kernel void kernel_mul_mv_id_q4_1_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4424,7 +5206,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4451,7 +5233,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + 
bid*ne0, ne00, ne01, ne02, @@ -4470,8 +5252,8 @@ kernel void kernel_mul_mv_id_q4_1_f32( kernel void kernel_mul_mv_id_q5_0_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4487,7 +5269,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4514,7 +5296,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4533,8 +5315,8 @@ kernel void kernel_mul_mv_id_q5_0_f32( kernel void kernel_mul_mv_id_q5_1_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4550,7 +5332,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4577,7 +5359,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4596,8 +5378,8 @@ kernel void kernel_mul_mv_id_q5_1_f32( kernel void kernel_mul_mv_id_q2_K_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4613,7 +5395,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4640,7 +5422,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( kernel_mul_mv_q2_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4659,8 +5441,8 @@ kernel void kernel_mul_mv_id_q2_K_f32( kernel void kernel_mul_mv_id_q3_K_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4676,7 +5458,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4703,7 +5485,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( kernel_mul_mv_q3_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4722,8 +5504,8 @@ kernel void kernel_mul_mv_id_q3_K_f32( kernel void kernel_mul_mv_id_q4_K_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4739,7 +5521,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( 
constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4766,7 +5548,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( kernel_mul_mv_q4_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4785,8 +5567,8 @@ kernel void kernel_mul_mv_id_q4_K_f32( kernel void kernel_mul_mv_id_q5_K_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4802,7 +5584,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4829,7 +5611,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( kernel_mul_mv_q5_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4848,8 +5630,8 @@ kernel void kernel_mul_mv_id_q5_K_f32( kernel void kernel_mul_mv_id_q6_K_f32( device const char * ids, device const char * src1, - device uchar * dst, - constant int64_t & nbi1, + device float * dst, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4865,7 +5647,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4892,7 +5674,136 @@ kernel void kernel_mul_mv_id_q6_K_f32( kernel_mul_mv_q6_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] +kernel void kernel_mul_mv_id_iq2_xxs_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq2_xxs_f32_impl( + src0[id], + (device const float *) (src1 + 
bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_iq2_xs_f32")]] +kernel void kernel_mul_mv_id_iq2_xs_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq2_xs_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, ne00, ne01, ne02, @@ -4902,6 +5813,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( ne1, r2, r3, + shared_values, tgpig, tiisg, sgitg); diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index 04411ac..70ffd34 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -36,70 +36,22 @@ struct wsp_ggml_cgraph; extern "C" { #endif -// -// internal API -// temporary exposed to user-code -// - -struct wsp_ggml_metal_context; - -void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data); - -// number of command buffers to use -struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb); -void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx); - -void * wsp_ggml_metal_host_malloc(size_t n); -void wsp_ggml_metal_host_free (void * data); - -// set the number of command buffers to use -void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb); - -// creates a mapping between a host memory buffer and a device memory buffer -// - make sure to map all buffers used in the graph before calling wsp_ggml_metal_graph_compute -// - the mapping is used during computation to determine the arguments of the compute kernels -// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal -// - max_size specifies the maximum size of a tensor and is used to create shared views such -// that it is guaranteed that the tensor will fit in at least one of the views -// -bool wsp_ggml_metal_add_buffer( - struct wsp_ggml_metal_context * ctx, - const char * name, - void * data, - size_t size, - size_t max_size); - -// set data from host memory into the device -void wsp_ggml_metal_set_tensor(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t); - -// get data from the device into host memory -void wsp_ggml_metal_get_tensor(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t); - -// try to find 
operations that can be run concurrently in the graph -// you should run it again if the topology of your graph changes -void wsp_ggml_metal_graph_find_concurrency(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf, bool check_mem); - -// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized -int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx); - -// output the concur_list for wsp_ggml_alloc -int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx); - -// same as wsp_ggml_graph_compute but uses Metal -// creates gf->n_threads command buffers in parallel -void wsp_ggml_metal_graph_compute(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf); - // // backend API // user-code should use only these functions // +WSP_GGML_API void wsp_ggml_backend_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data); + WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void); WSP_GGML_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend); +WSP_GGML_API WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); + WSP_GGML_API void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb); -WSP_GGML_API wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void); + +WSP_GGML_API WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family // ideally, the user code should be doing these checks diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 8d00eb8..7edd6ed 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -24,7 +24,7 @@ #define UNUSED(x) (void)(x) -#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE) +#define WSP_GGML_METAL_MAX_KERNELS 256 struct wsp_ggml_metal_buffer { const char * name; @@ -35,6 +35,134 @@ id metal; }; +struct wsp_ggml_metal_kernel { + id function; + id pipeline; +}; + +enum wsp_ggml_metal_kernel_type { + WSP_GGML_METAL_KERNEL_TYPE_ADD, + WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW, + WSP_GGML_METAL_KERNEL_TYPE_MUL, + WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW, + WSP_GGML_METAL_KERNEL_TYPE_DIV, + WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW, + WSP_GGML_METAL_KERNEL_TYPE_SCALE, + WSP_GGML_METAL_KERNEL_TYPE_SCALE_4, + WSP_GGML_METAL_KERNEL_TYPE_TANH, + WSP_GGML_METAL_KERNEL_TYPE_RELU, + WSP_GGML_METAL_KERNEL_TYPE_GELU, + WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK, + WSP_GGML_METAL_KERNEL_TYPE_SILU, + WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX, + WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, + WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, + WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, + WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, + WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, + WSP_GGML_METAL_KERNEL_TYPE_NORM, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, + 
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, + //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, + //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, + //WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, + WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, + WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32, + WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16, + WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32, + WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16, + WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, + WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, + WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, + WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, + WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, + 
//WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, + //WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, + WSP_GGML_METAL_KERNEL_TYPE_CONCAT, + WSP_GGML_METAL_KERNEL_TYPE_SQR, + WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, + + WSP_GGML_METAL_KERNEL_TYPE_COUNT +}; + struct wsp_ggml_metal_context { int n_cb; @@ -42,131 +170,15 @@ id queue; id library; - id command_buffers [WSP_GGML_METAL_MAX_COMMAND_BUFFERS]; - id command_encoders[WSP_GGML_METAL_MAX_COMMAND_BUFFERS]; - dispatch_queue_t d_queue; int n_buffers; struct wsp_ggml_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS]; - int concur_list[WSP_GGML_MAX_CONCUR]; - int concur_list_len; - - // custom kernels -#define WSP_GGML_METAL_DECL_KERNEL(name) \ - id function_##name; \ - id pipeline_##name - - WSP_GGML_METAL_DECL_KERNEL(add); - WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast - WSP_GGML_METAL_DECL_KERNEL(mul); - WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast - WSP_GGML_METAL_DECL_KERNEL(div); - WSP_GGML_METAL_DECL_KERNEL(div_row); - WSP_GGML_METAL_DECL_KERNEL(scale); - WSP_GGML_METAL_DECL_KERNEL(scale_4); - WSP_GGML_METAL_DECL_KERNEL(tanh); - WSP_GGML_METAL_DECL_KERNEL(relu); - WSP_GGML_METAL_DECL_KERNEL(gelu); - WSP_GGML_METAL_DECL_KERNEL(gelu_quick); - WSP_GGML_METAL_DECL_KERNEL(silu); - WSP_GGML_METAL_DECL_KERNEL(soft_max); - WSP_GGML_METAL_DECL_KERNEL(soft_max_4); - WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf); - WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf_8); - WSP_GGML_METAL_DECL_KERNEL(get_rows_f32); - WSP_GGML_METAL_DECL_KERNEL(get_rows_f16); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_K); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K); - WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K); - WSP_GGML_METAL_DECL_KERNEL(rms_norm); - WSP_GGML_METAL_DECL_KERNEL(group_norm); - WSP_GGML_METAL_DECL_KERNEL(norm); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); - //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); - //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row); - //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32); - 
WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32); - WSP_GGML_METAL_DECL_KERNEL(rope_f32); - WSP_GGML_METAL_DECL_KERNEL(rope_f16); - WSP_GGML_METAL_DECL_KERNEL(alibi_f32); - WSP_GGML_METAL_DECL_KERNEL(im2col_f16); - WSP_GGML_METAL_DECL_KERNEL(upscale_f32); - WSP_GGML_METAL_DECL_KERNEL(pad_f32); - WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc); - WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc); - WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32); - WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16); - WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32); - WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); - WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0); - WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); - //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); - //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); - WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16); - WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32); - WSP_GGML_METAL_DECL_KERNEL(concat); - WSP_GGML_METAL_DECL_KERNEL(sqr); - WSP_GGML_METAL_DECL_KERNEL(sum_rows); - -#undef WSP_GGML_METAL_DECL_KERNEL + struct wsp_ggml_metal_kernel kernels[WSP_GGML_METAL_MAX_KERNELS]; + + bool support_simdgroup_reduction; + bool support_simdgroup_mm; }; // MSL code @@ -180,14 +192,16 @@ @interface WSPGGMLMetalClass : NSObject @implementation WSPGGMLMetalClass @end -wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL; -void * wsp_ggml_metal_log_user_data = NULL; +static void wsp_ggml_metal_default_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) { + fprintf(stderr, "%s", msg); -void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) { - wsp_ggml_metal_log_callback = log_callback; - wsp_ggml_metal_log_user_data = user_data; + UNUSED(level); + UNUSED(user_data); } +wsp_ggml_log_callback wsp_ggml_metal_log_callback = wsp_ggml_metal_default_log_callback; +void * wsp_ggml_metal_log_user_data = NULL; + WSP_GGML_ATTRIBUTE_FORMAT(2, 3) static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){ if (wsp_ggml_metal_log_callback != NULL) { @@ -210,24 +224,33 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma } } -struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) { - 
WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__); +static void * wsp_ggml_metal_host_malloc(size_t n) { + void * data = NULL; + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + return NULL; + } + + return data; +} - id device; - NSString * s; +static struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) { + WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__); -#if TARGET_OS_OSX +#if TARGET_OS_OSX && !WSP_GGML_METAL_NDEBUG // Show all the Metal device instances in the system NSArray * devices = MTLCopyAllDevices(); - for (device in devices) { - s = [device name]; + for (id device in devices) { + NSString * s = [device name]; WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); } + [devices release]; // since it was created by a *Copy* C method #endif // Pick and show default Metal device - device = MTLCreateSystemDefaultDevice(); - s = [device name]; + id device = MTLCreateSystemDefaultDevice(); + NSString * s = [device name]; WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); // Configure context @@ -236,7 +259,6 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS); ctx->queue = [ctx->device newCommandQueue]; ctx->n_buffers = 0; - ctx->concur_list_len = 0; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); @@ -251,6 +273,7 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma NSError * error = nil; NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; if (libPath != nil) { + // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:libPath]; WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; @@ -278,12 +301,21 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma return NULL; } - MTLCompileOptions* options = nil; + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + #ifdef WSP_GGML_QKK_64 - options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; + prep[@"QK_K"] = @(64); #endif - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + //[options setFastMathEnabled:false]; + + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + } } if (error) { @@ -292,22 +324,51 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma } } -#if TARGET_OS_OSX // print MTL GPU family: WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); + const NSInteger MTLGPUFamilyMetal3 = 5001; + // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([ctx->device supportsFamily:i]) { - WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); - break; + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([ctx->device 
supportsFamily:i]) { + WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([ctx->device supportsFamily:i]) { + WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) { + if ([ctx->device supportsFamily:i]) { + WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i); + break; + } } } + ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7]; + ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3]; + + ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7]; + + WSP_GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false"); + WSP_GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false"); WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); + +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); + } +#elif TARGET_OS_OSX if (ctx->device.maxTransferRate != 0) { WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6); } else { @@ -319,286 +380,177 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * forma { NSError * error = nil; + for (int i = 0; i < WSP_GGML_METAL_MAX_KERNELS; ++i) { + ctx->kernels[i].function = nil; + ctx->kernels[i].pipeline = nil; + } + /* - WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ - (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ - (int) ctx->pipeline_##name.threadExecutionWidth); \ + WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ */ -#define WSP_GGML_METAL_ADD_KERNEL(name) \ - ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - if (error) { \ - WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - return NULL; \ +#define WSP_GGML_METAL_ADD_KERNEL(e, name, supported) \ + if (supported) { \ + struct wsp_ggml_metal_kernel * kernel = &ctx->kernels[e]; \ + kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + if (error) { \ + WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + return NULL; \ + } \ + } else { \ + WSP_GGML_METAL_LOG_WARN("%s: skipping %-32s (not 
supported)\n", __func__, "kernel_"#name); \ } - WSP_GGML_METAL_ADD_KERNEL(add); - WSP_GGML_METAL_ADD_KERNEL(add_row); - WSP_GGML_METAL_ADD_KERNEL(mul); - WSP_GGML_METAL_ADD_KERNEL(mul_row); - WSP_GGML_METAL_ADD_KERNEL(div); - WSP_GGML_METAL_ADD_KERNEL(div_row); - WSP_GGML_METAL_ADD_KERNEL(scale); - WSP_GGML_METAL_ADD_KERNEL(scale_4); - WSP_GGML_METAL_ADD_KERNEL(tanh); - WSP_GGML_METAL_ADD_KERNEL(relu); - WSP_GGML_METAL_ADD_KERNEL(gelu); - WSP_GGML_METAL_ADD_KERNEL(gelu_quick); - WSP_GGML_METAL_ADD_KERNEL(silu); - WSP_GGML_METAL_ADD_KERNEL(soft_max); - WSP_GGML_METAL_ADD_KERNEL(soft_max_4); - WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf); - WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf_8); - WSP_GGML_METAL_ADD_KERNEL(get_rows_f32); - WSP_GGML_METAL_ADD_KERNEL(get_rows_f16); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_K); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K); - WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K); - WSP_GGML_METAL_ADD_KERNEL(rms_norm); - WSP_GGML_METAL_ADD_KERNEL(group_norm); - WSP_GGML_METAL_ADD_KERNEL(norm); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); - //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); - //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row); - //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32); - 
WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32); - } - WSP_GGML_METAL_ADD_KERNEL(rope_f32); - WSP_GGML_METAL_ADD_KERNEL(rope_f16); - WSP_GGML_METAL_ADD_KERNEL(alibi_f32); - WSP_GGML_METAL_ADD_KERNEL(im2col_f16); - WSP_GGML_METAL_ADD_KERNEL(upscale_f32); - WSP_GGML_METAL_ADD_KERNEL(pad_f32); - WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc); - WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc); - WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32); - WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16); - WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); - WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0); - WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); - //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); - //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); - WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16); - WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(concat); - WSP_GGML_METAL_ADD_KERNEL(sqr); - WSP_GGML_METAL_ADD_KERNEL(sum_rows); - -#undef WSP_GGML_METAL_ADD_KERNEL + // simd_sum and simd_max requires MTLGPUFamilyApple7 + + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD, add, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL, mul, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV, div, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RELU, relu, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU, silu, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, 
true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NORM, norm, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, 
mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + 
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; } -void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { +static void 
wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); -#define WSP_GGML_METAL_DEL_KERNEL(name) \ - - WSP_GGML_METAL_DEL_KERNEL(add); - WSP_GGML_METAL_DEL_KERNEL(add_row); - WSP_GGML_METAL_DEL_KERNEL(mul); - WSP_GGML_METAL_DEL_KERNEL(mul_row); - WSP_GGML_METAL_DEL_KERNEL(div); - WSP_GGML_METAL_DEL_KERNEL(div_row); - WSP_GGML_METAL_DEL_KERNEL(scale); - WSP_GGML_METAL_DEL_KERNEL(scale_4); - WSP_GGML_METAL_DEL_KERNEL(tanh); - WSP_GGML_METAL_DEL_KERNEL(relu); - WSP_GGML_METAL_DEL_KERNEL(gelu); - WSP_GGML_METAL_DEL_KERNEL(gelu_quick); - WSP_GGML_METAL_DEL_KERNEL(silu); - WSP_GGML_METAL_DEL_KERNEL(soft_max); - WSP_GGML_METAL_DEL_KERNEL(soft_max_4); - WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf); - WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf_8); - WSP_GGML_METAL_DEL_KERNEL(get_rows_f32); - WSP_GGML_METAL_DEL_KERNEL(get_rows_f16); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_K); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K); - WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K); - WSP_GGML_METAL_DEL_KERNEL(rms_norm); - WSP_GGML_METAL_DEL_KERNEL(group_norm); - WSP_GGML_METAL_DEL_KERNEL(norm); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); - //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); - //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row); - //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32); - 
WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
-        WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
-    }
-    WSP_GGML_METAL_DEL_KERNEL(rope_f32);
-    WSP_GGML_METAL_DEL_KERNEL(rope_f16);
-    WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
-    WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
-    WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
-    WSP_GGML_METAL_DEL_KERNEL(pad_f32);
-    WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
-    WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
-    WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
-    //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
-    //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
-    WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
-    WSP_GGML_METAL_DEL_KERNEL(concat);
-    WSP_GGML_METAL_DEL_KERNEL(sqr);
-    WSP_GGML_METAL_DEL_KERNEL(sum_rows);
-
-#undef WSP_GGML_METAL_DEL_KERNEL
     free(ctx);
 }
-void * wsp_ggml_metal_host_malloc(size_t n) {
-    void * data = NULL;
-    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
-    if (result != 0) {
-        WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
-        return NULL;
-    }
-
-    return data;
-}
-
-void wsp_ggml_metal_host_free(void * data) {
-    free(data);
-}
-
-void wsp_ggml_metal_set_n_cb(struct wsp_ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
-}
+// temporarily defined here for compatibility between ggml-backend and the old API
-int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) {
-    return ctx->concur_list_len;
-}
+struct wsp_ggml_backend_metal_buffer {
+    void * data;
+    size_t size;
-int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
-    return ctx->concur_list;
-}
+    id<MTLBuffer> metal;
+};
-// temporarily defined here for compatibility between ggml-backend and the old API
 struct wsp_ggml_backend_metal_buffer_context {
-    void * data;
+    void * all_data;
+    size_t all_size;
+    bool owned;
-    id<MTLBuffer> metal;
+    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
+    int n_buffers;
+    struct wsp_ggml_backend_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
 };
 // finds the Metal buffer that contains the tensor data on the GPU device
@@ -610,17 +562,29 @@ int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) {
     const int64_t tsize = wsp_ggml_nbytes(t);
+    wsp_ggml_backend_buffer_t buffer = t->view_src ?
t->view_src->buffer : t->buffer; + // compatibility with ggml-backend - if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) { - struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context; + if (buffer && buffer->buft == wsp_ggml_backend_metal_buffer_type()) { + struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) buffer->context; - const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data; + // find the view that contains the tensor fully + for (int i = 0; i < buf_ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size); + //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { + *offs = (size_t) ioffs; - *offs = (size_t) ioffs; + //WSP_GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); - return buf_ctx->metal; + return buf_ctx->buffers[i].metal; + } + } + + WSP_GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + + return nil; } // find the view that contains the tensor fully @@ -642,210 +606,7 @@ int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) { return nil; } -bool wsp_ggml_metal_add_buffer( - struct wsp_ggml_metal_context * ctx, - const char * name, - void * data, - size_t size, - size_t max_size) { - if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) { - WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__); - return false; - } - - if (data) { - // verify that the buffer does not overlap with any of the existing buffers - for (int i = 0; i < ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; - - if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { - WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); - return false; - } - } - - const size_t size_page = sysconf(_SC_PAGESIZE); - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= ctx->device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].name = name; - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - - ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0); - return false; - } - - WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = ctx->device.maxBufferLength - size_ovlp; - const size_t size_view = 
ctx->device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].name = name; - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - - ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); - return false; - } - - WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); - if (i + size_step < size) { - WSP_GGML_METAL_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - -#if TARGET_OS_OSX - WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", - ctx->device.currentAllocatedSize / 1024.0 / 1024.0, - ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - - if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); - } else { - WSP_GGML_METAL_LOG_INFO("\n"); - } -#else - WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); -#endif - } - - return true; -} - -void wsp_ggml_metal_set_tensor( - struct wsp_ggml_metal_context * ctx, - struct wsp_ggml_tensor * t) { - size_t offs; - id id_dst = wsp_ggml_metal_get_buffer(ctx, t, &offs); - - memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, wsp_ggml_nbytes(t)); -} - -void wsp_ggml_metal_get_tensor( - struct wsp_ggml_metal_context * ctx, - struct wsp_ggml_tensor * t) { - size_t offs; - id id_src = wsp_ggml_metal_get_buffer(ctx, t, &offs); - - memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), wsp_ggml_nbytes(t)); -} - -void wsp_ggml_metal_graph_find_concurrency( - struct wsp_ggml_metal_context * ctx, - struct wsp_ggml_cgraph * gf, bool check_mem) { - int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time - int nodes_unused[WSP_GGML_MAX_CONCUR]; - - for (int i = 0; i < WSP_GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } - for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } - ctx->concur_list_len = 0; - - int n_left = gf->n_nodes; - int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list - int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos - - while (n_left > 0) { - // number of nodes at a layer (that can be issued concurrently) - int concurrency = 0; - for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { - if (nodes_unused[i]) { - // if the requirements for gf->nodes[i] are satisfied - int exe_flag = 1; - - // scan all srcs - for (int src_ind = 0; src_ind < WSP_GGML_MAX_SRC; src_ind++) { - struct wsp_ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; - if (src_cur) { - // if is leaf nodes it's satisfied. - // TODO: wsp_ggml_is_leaf() - if (src_cur->op == WSP_GGML_OP_NONE && src_cur->grad == NULL) { - continue; - } - - // otherwise this src should be the output from previous nodes. 
- int is_found = 0; - - // scan 2*search_depth back because we inserted barrier. - //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { - for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { - if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { - is_found = 1; - break; - } - } - if (is_found == 0) { - exe_flag = 0; - break; - } - } - } - if (exe_flag && check_mem) { - // check if nodes[i]'s data will be overwritten by a node before nodes[i]. - // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] - int64_t data_start = (int64_t) gf->nodes[i]->data; - int64_t length = (int64_t) wsp_ggml_nbytes(gf->nodes[i]); - for (int j = n_start; j < i; j++) { - if (nodes_unused[j] && gf->nodes[j]->op != WSP_GGML_OP_RESHAPE \ - && gf->nodes[j]->op != WSP_GGML_OP_VIEW \ - && gf->nodes[j]->op != WSP_GGML_OP_TRANSPOSE \ - && gf->nodes[j]->op != WSP_GGML_OP_PERMUTE) { - if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ - ((int64_t)gf->nodes[j]->data) + (int64_t) wsp_ggml_nbytes(gf->nodes[j]) <= data_start) { - continue; - } - - exe_flag = 0; - } - } - } - if (exe_flag) { - ctx->concur_list[level_pos + concurrency] = i; - nodes_unused[i] = 0; - concurrency++; - ctx->concur_list_len++; - } - } - } - n_left -= concurrency; - // adding a barrier different layer - ctx->concur_list[level_pos + concurrency] = -1; - ctx->concur_list_len++; - // jump all sorted nodes at nodes_bak - while (!nodes_unused[n_start]) { - n_start++; - } - level_pos += concurrency + 1; - } - - if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) { - WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__); - } -} - -static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) { +static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_metal_context * ctx, const struct wsp_ggml_tensor * op) { switch (op->op) { case WSP_GGML_OP_UNARY: switch (wsp_ggml_get_unary_op(op)) { @@ -871,9 +632,11 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) { case WSP_GGML_OP_SCALE: case WSP_GGML_OP_SQR: case WSP_GGML_OP_SUM_ROWS: + return true; case WSP_GGML_OP_SOFT_MAX: case WSP_GGML_OP_RMS_NORM: case WSP_GGML_OP_GROUP_NORM: + return ctx->support_simdgroup_reduction; case WSP_GGML_OP_NORM: case WSP_GGML_OP_ALIBI: case WSP_GGML_OP_ROPE: @@ -882,9 +645,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) { case WSP_GGML_OP_PAD: case WSP_GGML_OP_ARGSORT: case WSP_GGML_OP_LEAKY_RELU: + return true; case WSP_GGML_OP_MUL_MAT: case WSP_GGML_OP_MUL_MAT_ID: - return true; + return ctx->support_simdgroup_reduction; case WSP_GGML_OP_CPY: case WSP_GGML_OP_DUP: case WSP_GGML_OP_CONT: @@ -922,1433 +686,1559 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) { return false; } } -void wsp_ggml_metal_graph_compute( + +static bool wsp_ggml_metal_graph_compute( struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf) { - @autoreleasepool { - // if there is ctx->concur_list, dispatch concurrently - // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - - const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= WSP_GGML_MAX_CONCUR; - - const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; - edesc.dispatchType = has_concur ? 
MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+    edesc.dispatchType = MTLDispatchTypeSerial;
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
+    const int n_nodes = gf->n_nodes;
     const int n_cb = ctx->n_cb;
+    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
-    for (int i = 0; i < n_cb; ++i) {
-        ctx->command_buffers[i] = [ctx->queue commandBuffer];
+    id<MTLCommandBuffer> command_buffer_builder[n_cb];
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+        command_buffer_builder[cb_idx] = command_buffer;
         // enqueue the command buffers in order to specify their execution order
-        [ctx->command_buffers[i] enqueue];
-
-        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+        [command_buffer enqueue];
     }
+    const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
-    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
-
-        dispatch_async(ctx->d_queue, ^{
-            size_t offs_src0 = 0;
-            size_t offs_src1 = 0;
-            size_t offs_dst = 0;
-
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
-            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
-
-            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
-
-            for (int ind = node_start; ind < node_end; ++ind) {
-                const int i = has_concur ? ctx->concur_list[ind] : ind;
-
-                if (i == -1) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
-                    continue;
-                }
-
-                //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
-
-                struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
-                struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
-                struct wsp_ggml_tensor * dst = gf->nodes[i];
-
-                switch (dst->op) {
-                    case WSP_GGML_OP_NONE:
-                    case WSP_GGML_OP_RESHAPE:
-                    case WSP_GGML_OP_VIEW:
-                    case WSP_GGML_OP_TRANSPOSE:
-                    case WSP_GGML_OP_PERMUTE:
-                        {
-                            // noop -> next node
-                        } continue;
-                    default:
-                        {
-                        } break;
-                }
-
-                if (!wsp_ggml_metal_supports_op(dst)) {
-                    WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
-                    WSP_GGML_ASSERT(!"unsupported op");
-                }
-
-                const int64_t ne00 = src0 ? src0->ne[0] : 0;
-                const int64_t ne01 = src0 ? src0->ne[1] : 0;
-                const int64_t ne02 = src0 ? src0->ne[2] : 0;
-                const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
-                const uint64_t nb00 = src0 ? src0->nb[0] : 0;
-                const uint64_t nb01 = src0 ? src0->nb[1] : 0;
-                const uint64_t nb02 = src0 ? src0->nb[2] : 0;
-                const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
-                const int64_t ne10 = src1 ? src1->ne[0] : 0;
-                const int64_t ne11 = src1 ? src1->ne[1] : 0;
-                const int64_t ne12 = src1 ? src1->ne[2] : 0;
-                const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
-                const uint64_t nb10 = src1 ? src1->nb[0] : 0;
-                const uint64_t nb11 = src1 ? src1->nb[1] : 0;
-                const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-                const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
-                const int64_t ne0 = dst ? dst->ne[0] : 0;
-                const int64_t ne1 = dst ? dst->ne[1] : 0;
-                const int64_t ne2 = dst ? dst->ne[2] : 0;
-                const int64_t ne3 = dst ? dst->ne[3] : 0;
-
-                const uint64_t nb0 = dst ? dst->nb[0] : 0;
-                const uint64_t nb1 = dst ? dst->nb[1] : 0;
-                const uint64_t nb2 = dst ? dst->nb[2] : 0;
-                const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
-                const enum wsp_ggml_type src0t = src0 ?
src0->type : WSP_GGML_TYPE_COUNT; - const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT; - const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT; - - id id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; - id id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; - id id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; - - //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op)); - //if (src0) { - // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02, - // wsp_ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12, - // wsp_ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case WSP_GGML_OP_CONCAT: - { - const int64_t nb = ne00; - - [encoder setComputePipelineState:ctx->pipeline_concat]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_ADD: - case WSP_GGML_OP_MUL: - case WSP_GGML_OP_DIV: - { - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; - - id pipeline = nil; - - if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { + const int cb_idx = iter; - // src1 is a row - WSP_GGML_ASSERT(ne11 == 1); + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_dst = 0; - nb = ne00 / 4; - switch 
(dst->op) { - case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break; - case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break; - case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break; - default: WSP_GGML_ASSERT(false); - } + id command_buffer = command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - bcast_row = true; - } else { - switch (dst->op) { - case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break; - case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break; - case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break; - default: WSP_GGML_ASSERT(false); - } - } + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = wsp_ggml_nelements(dst)/4; + for (int i = node_start; i < node_end; ++i) { + if (i == -1) { + [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; + continue; + } - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op)); + + struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0]; + struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct wsp_ggml_tensor * dst = gf->nodes[i]; + + switch (dst->op) { + case WSP_GGML_OP_NONE: + case WSP_GGML_OP_RESHAPE: + case WSP_GGML_OP_VIEW: + case WSP_GGML_OP_TRANSPOSE: + case WSP_GGML_OP_PERMUTE: + { + // noop -> next node + } continue; + default: + { + } break; + } - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - 
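For reference, the dispatch_apply hunk above splits the graph's nodes evenly across the n_cb pre-enqueued command buffers, with the last range clamped to the end of the graph. Below is a minimal C sketch of that partitioning arithmetic; n_nodes, n_cb and the MIN clamp mirror the patch, while the concrete values in main are illustrative only.

// sketch: split n_nodes graph nodes across n_cb command buffers (ceil division),
// clamping the last range so no callback encodes past the end of the graph
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_nodes = 10;                                   // e.g. gf->n_nodes
    const int n_cb    = 3;                                    // e.g. ctx->n_cb
    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;   // ceil(n_nodes / n_cb)

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        const int node_start = cb_idx * n_nodes_per_cb;
        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
        printf("command buffer %d encodes nodes [%d, %d)\n", cb_idx, node_start, node_end);
    }
    return 0;
}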
case WSP_GGML_OP_ACC: - { - WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32); + if (!wsp_ggml_metal_supports_op(ctx, dst)) { + WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst)); + WSP_GGML_ASSERT(!"unsupported op"); + } - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const int nth = MIN(1024, ne00); - - [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } +#ifndef WSP_GGML_METAL_NDEBUG + [encoder pushDebugGroup:[NSString stringWithCString:wsp_ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; +#endif - [encoder setComputePipelineState:ctx->pipeline_add]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) 
atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_SCALE: - { + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT; + const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT; + const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT; + + id id_src0 = src0 ? wsp_ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; + id id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; + id id_dst = dst ? 
wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; + + //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op)); + //if (src0) { + // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02, + // wsp_ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12, + // wsp_ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case WSP_GGML_OP_CONCAT: + { + const int64_t nb = ne00; + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_ADD: + case WSP_GGML_OP_MUL: + case WSP_GGML_OP_DIV: + { + const size_t offs = 0; + + bool bcast_row = false; + + int64_t nb = ne00; + + id pipeline = nil; + + if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); - const float scale = *(const float *) src1->data; + // src1 is a row + WSP_GGML_ASSERT(ne11 == 1); - int64_t n = wsp_ggml_nelements(dst); + nb = ne00 / 4; + switch (dst->op) { + case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: WSP_GGML_ASSERT(false); + } - if (n % 
4 == 0) { - n /= 4; - [encoder setComputePipelineState:ctx->pipeline_scale_4]; - } else { - [encoder setComputePipelineState:ctx->pipeline_scale]; + bcast_row = true; + } else { + switch (dst->op) { + case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: WSP_GGML_ASSERT(false); } + } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = wsp_ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_OP_UNARY: - switch (wsp_ggml_get_unary_op(gf->nodes[i])) { - case WSP_GGML_UNARY_OP_TANH: - { - [encoder setComputePipelineState:ctx->pipeline_tanh]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - const int64_t n = wsp_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case WSP_GGML_OP_ACC: + { + WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_UNARY_OP_RELU: - { - [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src0 offset:offs_src0 
atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); - const int64_t n = wsp_ggml_nelements(dst); + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_UNARY_OP_GELU: - { - [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - const int64_t n = wsp_ggml_nelements(dst); - WSP_GGML_ASSERT(n % 4 == 0); + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_UNARY_OP_GELU_QUICK: - { - [encoder setComputePipelineState:ctx->pipeline_gelu_quick]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - const int64_t n = wsp_ggml_nelements(dst); - WSP_GGML_ASSERT(n % 4 == 0); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_UNARY_OP_SILU: - { - [encoder setComputePipelineState:ctx->pipeline_silu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - const int64_t n = wsp_ggml_nelements(dst); - WSP_GGML_ASSERT(n % 4 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); - WSP_GGML_ASSERT(false); - } - } break; - case WSP_GGML_OP_SQR: - { - 
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + const id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_SCALE: + { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + + const float scale = *(const float *) dst->op_params; + + int64_t n = wsp_ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } - [encoder setComputePipelineState:ctx->pipeline_sqr]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = wsp_ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_OP_SUM_ROWS: - { - WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type)); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_OP_UNARY: + switch (wsp_ggml_get_unary_op(gf->nodes[i])) { + case WSP_GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_TANH].pipeline; - [encoder setComputePipelineState:ctx->pipeline_sum_rows]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; 
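For reference, the reworked WSP_GGML_OP_SCALE case above now reads the scale factor from dst->op_params and, whenever the element count is divisible by 4, selects the SCALE_4 pipeline and launches a quarter as many threadgroups. A small C sketch of that selection follows; the kernel-name strings are illustrative labels, not Metal API identifiers.

// sketch: the 4-wide scale kernel handles a float4 per thread, so the
// dispatched element count is divided by 4 when it applies
#include <stdint.h>
#include <stdio.h>

int main(void) {
    int64_t n = 4096;                 // e.g. wsp_ggml_nelements(dst)
    const char * kernel = "kernel_scale";

    if (n % 4 == 0) {
        n /= 4;                       // each thread now covers four elements
        kernel = "kernel_scale_4";
    }

    printf("dispatch %lld threadgroups with %s\n", (long long) n, kernel);
    return 0;
}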
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_OP_SOFT_MAX: - { - int nth = 32; // SIMD width - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth < 256) { - nth *= 2; - } - [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; - } else { - while (nth < ne00 && nth < 1024) { - nth *= 2; - } - [encoder setComputePipelineState:ctx->pipeline_soft_max]; - } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const float scale = ((float *) dst->op_params)[0]; + const int64_t n = wsp_ggml_nelements(dst); - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - if (ne00%8 == 0) { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; - } else { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RELU].pipeline; - if (ne00%8 == 0) { - [encoder 
dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_UNARY_OP_GELU: + { + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + WSP_GGML_ASSERT(n % 4 == 0); + + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_UNARY_OP_GELU_QUICK: + { + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + WSP_GGML_ASSERT(n % 4 == 0); + + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_UNARY_OP_SILU: + { + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SILU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + WSP_GGML_ASSERT(n % 4 == 0); + + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); + WSP_GGML_ASSERT(false); } - } break; - case WSP_GGML_OP_MUL_MAT: - { - WSP_GGML_ASSERT(ne00 == ne10); + } break; + case WSP_GGML_OP_SQR: + { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_OP_SUM_ROWS: + { + WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) 
atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case WSP_GGML_OP_SOFT_MAX: + { + int nth = 32; // SIMD width + + id pipeline = nil; + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth < 256) { + nth *= 2; + } + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + } else { + while (nth < ne00 && nth < 1024) { + nth *= 2; + } + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + } - // TODO: assert that dim2 and dim3 are contiguous - WSP_GGML_ASSERT(ne12 % ne02 == 0); - WSP_GGML_ASSERT(ne13 % ne03 == 0); + const float scale = ((float *) dst->op_params)[0]; - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case WSP_GGML_OP_MUL_MAT: + { + WSP_GGML_ASSERT(ne00 == ne10); - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + // TODO: assert that dim2 and dim3 are contiguous + WSP_GGML_ASSERT(ne12 % ne02 == 0); + WSP_GGML_ASSERT(ne13 % ne03 == 0); 
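For reference, the WSP_GGML_OP_SOFT_MAX case above picks the threadgroup width by starting at the 32-lane SIMD width and doubling until the row is covered, capped at 256 for the 4-wide kernel and 1024 for the scalar one. A standalone C sketch of that selection; the sample ne00 values are illustrative, not taken from the patch.

// sketch: threadgroup width selection for the soft_max kernels
#include <stdio.h>

static int soft_max_nth(int ne00) {
    int nth = 32;                                           // SIMD width
    if (ne00 % 4 == 0) {
        while (nth < ne00/4 && nth < 256)  { nth *= 2; }    // 4-wide kernel, cap 256
    } else {
        while (nth < ne00   && nth < 1024) { nth *= 2; }    // scalar kernel, cap 1024
    }
    return nth;
}

int main(void) {
    printf("ne00 = 1500 -> nth = %d\n", soft_max_nth(1500));
    printf("ne00 =  512 -> nth = %d\n", soft_max_nth(512));
    return 0;
}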
+ + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach - if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { - switch (src0t) { - case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break; - case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break; - case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break; - case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break; - case WSP_GGML_TYPE_Q4_0: - case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break; - case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break; - case WSP_GGML_TYPE_Q5_0: // not tested yet - case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet - case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break; - case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break; - default: ne11_mm_min = 1; break; - } + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach + if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + switch (src0t) { + case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break; + case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break; + case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break; + case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break; + case WSP_GGML_TYPE_Q5_0: // not tested yet + case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet + case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break; + default: ne11_mm_min = 1; break; } + } #endif - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !wsp_ggml_is_transposed(src0) && - !wsp_ggml_is_transposed(src1) && - src1t == WSP_GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - switch (src0->type) { - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; - case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; - case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; - case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; - case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; - case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; - case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; - case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; - case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; - case WSP_GGML_TYPE_Q5_K: [encoder 
setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; - case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; - default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented"); - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // use custom matrix x vector kernel - switch (src0t) { - case WSP_GGML_TYPE_F32: - { - WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; - nrows = 4; - } break; - case WSP_GGML_TYPE_F16: - { - nth0 = 32; - nth1 = 1; - if (src1t == WSP_GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; - nrows = ne11; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; - nrows = 4; - } + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + !wsp_ggml_is_transposed(src0) && + !wsp_ggml_is_transposed(src1) && + src1t == WSP_GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + switch (src0->type) { + case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q2_K: pipeline = 
ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + default: WSP_GGML_ASSERT(false && "MUL MAT-MAT not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + nrows = 4; + } break; + case WSP_GGML_TYPE_F16: + { + nth0 = 32; + nth1 = 1; + if (src1t == WSP_GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; + nrows = ne11; } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16]; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; nrows = 4; } - } break; - case WSP_GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; - } break; - case WSP_GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; - } break; - case WSP_GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; - } break; - case WSP_GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; - } break; - case WSP_GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - [encoder 
setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; - } break; - case WSP_GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; - } break; - case WSP_GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; - } break; - case WSP_GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; - } break; - case WSP_GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; - } break; - case WSP_GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; - } break; - default: - { - WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); - WSP_GGML_ASSERT(false && "not implemented"); + } else { + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; + nrows = 4; } - }; + } break; + case WSP_GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; + } break; + case WSP_GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; + } break; + default: + { + WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + WSP_GGML_ASSERT(false && "not implemented"); + } + }; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - 
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - - if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || - src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || - src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == WSP_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == WSP_GGML_TYPE_Q3_K) { -#ifdef WSP_GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#else - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#endif - } - else if (src0t == WSP_GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == WSP_GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } + if (wsp_ggml_is_quantized(src0t)) { + WSP_GGML_ASSERT(ne00 >= nth0*nth1); } - } break; - case WSP_GGML_OP_MUL_MAT_ID: - { - //WSP_GGML_ASSERT(ne00 == ne10); - //WSP_GGML_ASSERT(ne03 == ne13); - - WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32); - - const int n_as = ((int32_t *) dst->op_params)[1]; - - // TODO: make this more general - WSP_GGML_ASSERT(n_as <= 8); - - struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2]; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; - const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23); - - const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t); - - WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2)); - WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1)); - - WSP_GGML_ASSERT(ne20 % 32 == 0); - // !!!!!!!!! TODO: this assert is probably required but not sure! 
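For reference, the MUL_MAT matrix-vector dispatch in this hunk groups several output rows per threadgroup depending on the quantization type: 8 rows for Q4_0/Q4_1/Q5_0/Q5_1/Q8_0/Q2_K, 4 rows for Q3_K/Q4_K/Q5_K (2 for Q3_K when WSP_GGML_QKK_64 is defined), 2 rows for Q6_K, and otherwise ny = (ne11 + nrows - 1)/nrows groups along the second grid dimension. A small C sketch of the threadgroup-count computation; the string labels merely stand in for the wsp_ggml_type enum values.

// sketch: rows of src0 covered per threadgroup in the mat-vec launch
// (the same grouping appears on both the removed and re-added sides of the diff)
#include <stdio.h>
#include <string.h>

static int rows_per_threadgroup(const char * type) {
    if (!strcmp(type, "q4_0") || !strcmp(type, "q4_1") ||
        !strcmp(type, "q5_0") || !strcmp(type, "q5_1") ||
        !strcmp(type, "q8_0") || !strcmp(type, "q2_K")) return 8;   // (ne01 + 7)/8 threadgroups
    if (!strcmp(type, "q3_K") ||                                    // 2 when WSP_GGML_QKK_64 is defined
        !strcmp(type, "q4_K") || !strcmp(type, "q5_K")) return 4;   // (ne01 + 3)/4 threadgroups
    if (!strcmp(type, "q6_K"))                          return 2;   // (ne01 + 1)/2 threadgroups
    return 1;                                                       // f32/f16 paths group by nrows instead
}

int main(void) {
    const int ne01 = 4096;            // rows of src0, illustrative
    const char * t = "q4_K";
    const int r = rows_per_threadgroup(t);
    printf("%s: %d threadgroups\n", t, (ne01 + r - 1) / r);
    return 0;
}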
- //WSP_GGML_ASSERT(ne20 >= 64); - WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - - const uint r2 = ne12/ne22; - const uint r3 = ne13/ne23; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; - - const int idx = ((int32_t *) dst->op_params)[0]; - - // batch size - WSP_GGML_ASSERT(ne01 == ne11); - - const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { - switch (src2->type) { - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; - case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break; - case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break; - case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break; - case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break; - case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break; - case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break; - case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break; - case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break; - case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break; - case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break; - default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented"); - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; - // TODO: how to make this an array? 
read Metal docs - for (int j = 0; j < n_as; ++j) { - struct wsp_ggml_tensor * src_cur = dst->src[2 + j]; - - size_t offs_src_cur = 0; - id id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; - } - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - - // TODO: processing one row at a time (ne11 -> 1) is not efficient - [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // use custom matrix x vector kernel - switch (src2t) { - case WSP_GGML_TYPE_F32: - { - WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32]; - } break; - case WSP_GGML_TYPE_F16: - { - WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32]; - } break; - case WSP_GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32]; - } break; - case WSP_GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32]; - } break; - case WSP_GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32]; - } break; - case WSP_GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32]; - } break; - case WSP_GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32]; - } break; - case WSP_GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32]; - } break; - case WSP_GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32]; - } break; - case WSP_GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32]; - } break; - case WSP_GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32]; - } break; - case WSP_GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32]; - } break; - default: - { - WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); - WSP_GGML_ASSERT(false && "not implemented"); - } - }; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 
length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; - // TODO: how to make this an array? read Metal docs - for (int j = 0; j < n_as; ++j) { - struct wsp_ggml_tensor * src_cur = dst->src[2 + j]; - - size_t offs_src_cur = 0; - id id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; - } - - if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 || - src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 || - src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src2t == WSP_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src2t == WSP_GGML_TYPE_Q3_K) { + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + + if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || + src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || + src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 
256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == WSP_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == WSP_GGML_TYPE_Q3_K) { #ifdef WSP_GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif - } - else if (src2t == WSP_GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src2t == WSP_GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } } - } break; - case WSP_GGML_OP_GET_ROWS: - { - switch (src0->type) { - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; - case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; - case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; - case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; - case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; - case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; - case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; - case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; - case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; - case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; - default: WSP_GGML_ASSERT(false && "not implemented"); + else if (src0t == WSP_GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 
length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case WSP_GGML_OP_RMS_NORM: - { - WSP_GGML_ASSERT(ne00 % 4 == 0); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; + else if (src0t == WSP_GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (ne11 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + } + } break; + case WSP_GGML_OP_MUL_MAT_ID: + { + //WSP_GGML_ASSERT(ne00 == ne10); + //WSP_GGML_ASSERT(ne03 == ne13); - [encoder setComputePipelineState:ctx->pipeline_rms_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = wsp_ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_GROUP_NORM: - { - WSP_GGML_ASSERT(ne00 % 4 == 0); - - //float eps; - //memcpy(&eps, dst->op_params, sizeof(float)); - - const float eps = 1e-6f; // TODO: temporarily hardcoded - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - [encoder setComputePipelineState:ctx->pipeline_group_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_NORM: - { - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = wsp_ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_ALIBI: - { - WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32)); - - const int nth = MIN(1024, ne00); - - 
//const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32); - [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; + const int n_as = ((int32_t *) dst->op_params)[1]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_ROPE: - { - WSP_GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + // TODO: make this more general + WSP_GGML_ASSERT(n_as <= 8); - switch (src0->type) { - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; - default: WSP_GGML_ASSERT(false); - }; + // max size of the src1ids array in the kernel stack + WSP_GGML_ASSERT(ne11 <= 512); - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - 
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_IM2COL: - { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; + const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23); - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23); - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; + const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t); - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; + WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2)); + WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1)); - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; + WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); - const int32_t CHW = IC * KH * KW; + const uint r2 = ne12/ne22; + const uint r3 = ne13/ne23; - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 
2 : 1] / 4; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = n_as; - switch (src0->type) { - case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break; - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break; - default: WSP_GGML_ASSERT(false); - }; + const int idx = ((int32_t *) dst->op_params)[0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case WSP_GGML_OP_UPSCALE: - { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); - - const int sf = dst->op_params[0]; - - [encoder setComputePipelineState:ctx->pipeline_upscale_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_PAD: - { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); - - [encoder setComputePipelineState:ctx->pipeline_pad_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) 
atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case WSP_GGML_OP_ARGSORT: - { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32); - - const int nrows = wsp_ggml_nrows(src0); - - enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0]; - - switch (order) { - case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break; - case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break; - default: WSP_GGML_ASSERT(false); - }; + // batch size + WSP_GGML_ASSERT(ne01 == ne11); - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne20 % 32 == 0 && ne20 >= 64 && + ne11 > ne11_mm_min) { - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; - } break; - case WSP_GGML_OP_LEAKY_RELU: - { - WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); + id pipeline = nil; - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); + switch (src2->type) { + case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; 
break; + case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented"); + } - [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; + // TODO: how to make this an array? read Metal docs + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; + + size_t offs_src_cur = 0; + id id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); + + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; + } - const int64_t n = wsp_ggml_nelements(dst); + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case WSP_GGML_OP_DUP: - case WSP_GGML_OP_CPY: - case WSP_GGML_OP_CONT: - { - WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0); + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type)); + id pipeline = nil; - switch (src0t) { + // use custom matrix x vector kernel + switch (src2t) { case WSP_GGML_TYPE_F32: { - WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; - case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break; - case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break; - case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; - //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break; - //case WSP_GGML_TYPE_Q5_1: [encoder 
setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break; - default: WSP_GGML_ASSERT(false && "not implemented"); - }; + WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case WSP_GGML_TYPE_F16: { - switch (dstt) { - case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; - case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break; - default: WSP_GGML_ASSERT(false && "not implemented"); - }; + WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case WSP_GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; - default: WSP_GGML_ASSERT(false && "not implemented"); + case WSP_GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + default: + { + WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); + WSP_GGML_ASSERT(false && "not implemented"); + } + }; + + if (wsp_ggml_is_quantized(src2t)) { + WSP_GGML_ASSERT(ne20 >= nth0*nth1); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder 
setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + const int64_t _ne1 = 1; // kernels needs a reference in constant memory - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); - WSP_GGML_ASSERT(false); + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; + // TODO: how to make this an array? read Metal docs + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct wsp_ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; + + size_t offs_src_cur = 0; + id id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); + + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; + } + + if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 || + src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 || + src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == WSP_GGML_TYPE_IQ2_XXS || src2t == WSP_GGML_TYPE_IQ2_XS) { + const int mem_size = src2t == WSP_GGML_TYPE_IQ2_XXS ? 
256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == WSP_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == WSP_GGML_TYPE_Q3_K) { +#ifdef WSP_GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src2t == WSP_GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == WSP_GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case WSP_GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case WSP_GGML_TYPE_I32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: WSP_GGML_ASSERT(false && "not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) 
atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case WSP_GGML_OP_RMS_NORM: + { + WSP_GGML_ASSERT(ne00 % 4 == 0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; } - } - } - if (encoder != nil) { - [encoder endEncoding]; - encoder = nil; + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = wsp_ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_GROUP_NORM: + { + WSP_GGML_ASSERT(ne00 % 4 == 0); + + //float eps; + //memcpy(&eps, dst->op_params, sizeof(float)); + + const float eps = 1e-6f; // TODO: temporarily hardcoded + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_NORM: + { + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = MIN(256, ne00); + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + + const int64_t nrows = wsp_ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_ALIBI: + { + WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32)); + + const int nth = MIN(1024, ne00); + + //const int n_past = ((int32_t *) 
dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; + [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_ROPE: + { + WSP_GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + id pipeline = nil; + + switch (src0->type) { + case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; + case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; + default: WSP_GGML_ASSERT(false); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; + 
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; + [encoder setBytes:&mode length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_IM2COL: + { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 
2 : 1] / 4; + + id pipeline = nil; + + switch (src0->type) { + case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break; + case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + default: WSP_GGML_ASSERT(false); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; + case WSP_GGML_OP_UPSCALE: + { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); + + const int sf = dst->op_params[0]; + + const id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case WSP_GGML_OP_PAD: + { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32); + + id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; 
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case WSP_GGML_OP_ARGSORT:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
+
+ const int nrows = wsp_ggml_nrows(src0);
+
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
+
+ id pipeline = nil;
+
+ switch (order) {
+ case WSP_GGML_SORT_ASC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case WSP_GGML_SORT_DESC: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: WSP_GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
+ case WSP_GGML_OP_LEAKY_RELU:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case WSP_GGML_OP_DUP:
+ case WSP_GGML_OP_CPY:
+ case WSP_GGML_OP_CONT:
+ {
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
+
+ id pipeline = nil;
+
+ switch (src0t) {
+ case WSP_GGML_TYPE_F32:
+ {
+ WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ //case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ //case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ default: WSP_GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case WSP_GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ case WSP_GGML_TYPE_F32: pipeline =
ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + default: WSP_GGML_ASSERT(false && "not implemented"); + }; + } break; + default: WSP_GGML_ASSERT(false && "not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); + WSP_GGML_ASSERT(false); + } } - [command_buffer commit]; - }); - } +#ifndef WSP_GGML_METAL_NDEBUG + [encoder popDebugGroup]; +#endif + } + + [encoder endEncoding]; - // wait for all threads to finish - dispatch_barrier_sync(ctx->d_queue, ^{}); + [command_buffer commit]; + }); - // check status of command buffers + // Wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) - for (int i = 0; i < n_cb; i++) { - [ctx->command_buffers[i] waitUntilCompleted]; - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; + for (int i = 0; i < n_cb; ++i) { + id command_buffer = command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - WSP_GGML_ASSERT(false); + return false; } } - } + return true; } //////////////////////////////////////////////////////////////////////////////// // backend interface +// default buffer static id g_backend_device = nil; static int g_backend_device_ref_count = 0; @@ -2372,64 +2262,98 @@ static void wsp_ggml_backend_metal_free_device(void) { } } -static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { - struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context; +WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_get_name(wsp_ggml_backend_buffer_t buffer) { + return "Metal"; - return ctx->data; + UNUSED(buffer); } -static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { +WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { struct wsp_ggml_backend_metal_buffer_context * ctx = (struct 
wsp_ggml_backend_metal_buffer_context *)buffer->context; wsp_ggml_backend_metal_free_device(); - free(ctx->data); - free(ctx); + if (ctx->owned) { + free(ctx->all_data); + } - UNUSED(buffer); + free(ctx); } -static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); - WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); +WSP_GGML_CALL static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { + struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context; + + return ctx->all_data; +} +WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); UNUSED(buffer); } -static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { - WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); - WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - +WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); UNUSED(buffer); } -static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { - wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src)); +WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + if (wsp_ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, wsp_ggml_nbytes(src)); + return true; + } + return false; UNUSED(buffer); } -static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { - wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src)); +WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) { + struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context; - UNUSED(buffer); + memset(ctx->all_data, value, ctx->all_size); } -static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = { +static struct wsp_ggml_backend_buffer_i wsp_ggml_backend_metal_buffer_i = { + /* .get_name = */ wsp_ggml_backend_metal_buffer_get_name, /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer, /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, /* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from, - /* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to, + /* .cpy_tensor = */ wsp_ggml_backend_metal_buffer_cpy_tensor, + /* .clear = */ wsp_ggml_backend_metal_buffer_clear, + /* .reset = */ NULL, }; -static wsp_ggml_backend_buffer_t 
wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { +// default buffer type + +WSP_GGML_CALL static const char * wsp_ggml_backend_metal_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) { + return "Metal"; + + UNUSED(buft); +} + +static void wsp_ggml_backend_metal_log_allocated_size(id device) { +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } else { + WSP_GGML_METAL_LOG_INFO("\n"); + } + } else { + WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + } +#endif + UNUSED(device); +} + +WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) { struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context)); const size_t size_page = sysconf(_SC_PAGESIZE); @@ -2439,33 +2363,59 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer size_aligned += (size_page - (size_aligned % size_page)); } - ctx->data = wsp_ggml_metal_host_malloc(size); - ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data + id device = wsp_ggml_backend_metal_get_device(); + + ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned); + ctx->all_size = size_aligned; + ctx->owned = true; + ctx->n_buffers = 1; + + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size); + if (ctx->buffers[0].metal == nil) { + WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(ctx); + wsp_ggml_backend_metal_free_device(); + return NULL; + } + + WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + wsp_ggml_backend_metal_log_allocated_size(device); + + return wsp_ggml_backend_buffer_init(buft, wsp_ggml_backend_metal_buffer_i, ctx, size); } -static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) { +WSP_GGML_CALL static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) { return 32; UNUSED(buft); } -static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) { +WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) { return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend); - WSP_GGML_UNUSED(buft); + UNUSED(buft); +} + +WSP_GGML_CALL static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) { + return true; + + UNUSED(buft); } -wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) { +WSP_GGML_CALL wsp_ggml_backend_buffer_type_t 
wsp_ggml_backend_metal_buffer_type(void) { static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = { /* .iface = */ { + /* .get_name = */ wsp_ggml_backend_metal_buffer_type_get_name, /* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer, /* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment, /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend, + /* .is_host = */ wsp_ggml_backend_metal_buffer_type_is_host, }, /* .context = */ NULL, }; @@ -2473,67 +2423,134 @@ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) { return &wsp_ggml_backend_buffer_type_metal; } -static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) { +// buffer from ptr + +WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { + struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context)); + + ctx->all_data = data; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) data % size_page; + data = (void *) ((char *) data - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + id device = wsp_ggml_backend_metal_get_device(); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + + WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? 
size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + + WSP_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + if (i + size_step < size) { + WSP_GGML_METAL_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + wsp_ggml_backend_metal_log_allocated_size(device); + + return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_type(), wsp_ggml_backend_metal_buffer_i, ctx, size); +} + +// backend + +WSP_GGML_CALL static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) { return "Metal"; UNUSED(backend); } -static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) { +WSP_GGML_CALL static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) { struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context; wsp_ggml_metal_free(ctx); free(backend); } -static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) { - UNUSED(backend); -} - -static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) { +WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) { return wsp_ggml_backend_metal_buffer_type(); UNUSED(backend); } -static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { +WSP_GGML_CALL static bool wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context; - wsp_ggml_metal_graph_compute(metal_ctx, cgraph); + return wsp_ggml_metal_graph_compute(metal_ctx, cgraph); } -static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { - return wsp_ggml_metal_supports_op(op); +WSP_GGML_CALL static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { + struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context; - UNUSED(backend); + return wsp_ggml_metal_supports_op(metal_ctx, op); } -static struct wsp_ggml_backend_i metal_backend_i = { +static struct wsp_ggml_backend_i wsp_ggml_backend_metal_i = { /* .get_name = */ wsp_ggml_backend_metal_name, /* .free = */ wsp_ggml_backend_metal_free, /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .cpy_tensor_from_async = */ NULL, - /* .cpy_tensor_to_async = */ NULL, - /* .synchronize = */ wsp_ggml_backend_metal_synchronize, - /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute, /* .supports_op = */ 
wsp_ggml_backend_metal_supports_op, }; -// TODO: make a common log callback for all backends in ggml-backend -static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) { - fprintf(stderr, "%s", msg); - - UNUSED(level); - UNUSED(user_data); +void wsp_ggml_backend_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) { + wsp_ggml_metal_log_callback = log_callback; + wsp_ggml_metal_log_user_data = user_data; } wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) { - wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL); - struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS); if (ctx == NULL) { @@ -2543,7 +2560,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) { wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend)); *metal_backend = (struct wsp_ggml_backend) { - /* .interface = */ metal_backend_i, + /* .interface = */ wsp_ggml_backend_metal_i, /* .context = */ ctx, }; @@ -2551,7 +2568,7 @@ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) { } bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) { - return backend->iface.get_name == wsp_ggml_backend_metal_name; + return backend && backend->iface.get_name == wsp_ggml_backend_metal_name; } void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) { @@ -2559,7 +2576,7 @@ void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) { struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context; - wsp_ggml_metal_set_n_cb(ctx, n_cb); + ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS); } bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) { @@ -2570,9 +2587,9 @@ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int fami return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } -wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning +WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning -wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) { +WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) { return wsp_ggml_backend_metal_init(); WSP_GGML_UNUSED(params); diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 7ca5b7d..3a5d0cc 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -5,6 +5,8 @@ #include #include #include +#include // for qsort +#include // for WSP_GGML_ASSERT #ifdef __ARM_NEON @@ -272,10 +274,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 // vaddvq_s16 // vpaddq_s16 +// vpaddq_s32 // vaddvq_s32 // vaddvq_f32 // vmaxvq_f32 // vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 inline static int32_t vaddvq_s16(int16x8_t v) { return @@ -291,6 +296,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { return vcombine_s16(a0, b0); } +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -316,6 +327,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { return res; } +inline 
static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + // vld1q_s16_x2 // vld1q_u8_x2 // vld1q_u8_x4 @@ -407,6 +440,22 @@ inline static wsp_ggml_int8x16x4_t wsp_ggml_vld1q_s8_x4(const int8_t * ptr) { #define wsp_ggml_vld1q_s8_x4 vld1q_s8_x4 #endif + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t wsp_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define wsp_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif + #endif #if defined(__ARM_NEON) || defined(__wasm_simd128__) @@ -466,6 +515,7 @@ void wsp_quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { wsp_quantize_row_q4_0_reference(x, y, k); } + void wsp_quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { const int qk = QK4_1; @@ -1195,7 +1245,8 @@ static inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type, + const float * restrict qw) { float max = 0; float amax = 0; for (int i = 0; i < n; ++i) { @@ -1221,14 +1272,18 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * rmse_type = -rmse_type; return_early = true; } - int weight_type = rmse_type%2; float sumlx = 0; float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else for (int i = 0; i < n; ++i) { +#endif int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); L[i] = l + nmax; - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1244,7 +1299,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * for (int i = 0; i < n; ++i) { int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1592,6 +1647,246 @@ size_t wsp_ggml_wsp_quantize_q2_K(const float * restrict src, void * restrict ds return (n/QK_K*sizeof(block_q2_K)); } +static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights ? 
weights[0] : x[0]*x[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights ? weights[i] : x[i]*x[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) { + min = 0; + } + if (max <= min) { + memset(L, 0, n); + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights ? weights[i] : x[i]*x[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) { + float max = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + if (!max) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = nmax / max; + for (int i = 0; i < n; ++i) { + L[i] = nearest_int(iscale * x[i]); + } + float scale = 1/iscale; + float best_mse = 0; + for (int i = 0; i < n; ++i) { + float diff = x[i] - scale*L[i]; + float w = quant_weights[i]; + best_mse += w*diff*diff; + } + for (int is = -4; is <= 4; ++is) { + if (is == 0) continue; + float iscale_is = (0.1f*is + nmax)/max; + float scale_is = 1/iscale_is; + float mse = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale_is*x[i]); + l = MIN(nmax, l); + float diff = x[i] - scale_is*l; + float w = quant_weights[i]; + mse += w*diff*diff; + } + if (mse < best_mse) { + best_mse = mse; + iscale = iscale_is; + } + } + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MIN(nmax, l); + L[i] = l; + float w = quant_weights[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = quant_weights[i]; + float slx = sumlx - w*x[i]*L[i]; + float sl2 = suml2 - w*L[i]*L[i]; + if (slx > 0 && sl2 > 0) { + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MIN(nmax, new_l); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = 
sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + return sumlx / suml2; +} + +static void wsp_quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) { + WSP_GGML_ASSERT(quant_weights); + assert(k % QK_K == 0); + const int nb = k / QK_K; + const bool requantize = true; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + float sw[QK_K/16]; + float weight[QK_K/16]; + uint8_t Ls[QK_K/16], Lm[QK_K/16]; + + for (int i = 0; i < nb; i++) { + memset(sw, 0, QK_K/16*sizeof(float)); + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = sumx2/QK_K; + for (int j = 0; j < QK_K/16; ++j) { + const float * restrict qw = quant_weights + QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + for (int l = 0; l < 16; ++l) sw[j] += weight[l]; + scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + y[i].d = WSP_GGML_FP32_TO_FP16(dm); + y[i].dmin = WSP_GGML_FP32_TO_FP16(mm); + dm = WSP_GGML_FP16_TO_FP32(y[i].d); + mm = WSP_GGML_FP16_TO_FP32(y[i].dmin); + + for (int j = 0; j < QK_K/16; ++j) { + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } + + if (requantize) { + for (int j = 0; j < QK_K/16; ++j) { + const float d = dm * (y[i].scales[j] & 0xF); + if (!d) continue; + const float m = mm * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + m)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +size_t wsp_quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q2_K, n_per_row); + if (!quant_weights) { + wsp_quantize_row_q2_K_reference(src, dst, nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + //========================= 3-bit (de)-quantization void wsp_quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { @@ -1805,6 +2100,112 @@ size_t wsp_ggml_wsp_quantize_q3_K(const float * restrict src, void * restrict ds return (n/QK_K*sizeof(block_q3_K)); } +static void wsp_quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) { +#if QK_K != 256 + (void)quant_weights; + wsp_quantize_row_q3_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + float weight[16]; + float sw[QK_K / 16]; + int8_t Ls[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = 2*sumx2/QK_K; + + for (int j = 0; j < QK_K/16; ++j) { + if (quant_weights) { + 
const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]); + } else { + for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l]; + } + float sumw = 0; + for (int l = 0; l < 16; ++l) sumw += weight[l]; + sw[j] = sumw; + + scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight); + + } + + memset(y[i].scales, 0, 12); + + float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw); + for (int j = 0; j < QK_K/16; ++j) { + int l = Ls[j]; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = WSP_GGML_FP32_TO_FP16(d_block); + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +#endif +} + +size_t wsp_quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q3_K, n_per_row); + if (!quant_weights) { + wsp_quantize_row_q3_K_reference(src, dst, nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + // ====================== 4-bit (de)-quantization void wsp_quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { @@ -1970,6 +2371,108 @@ size_t wsp_ggml_wsp_quantize_q4_K(const float * restrict src, void * restrict ds return (n/QK_K*sizeof(block_q4_K)); } +static void wsp_quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + wsp_quantize_row_q4_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + float weights[32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, 
-0.9f, 0.05f, 36, false); + //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +#endif +} + +size_t wsp_quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_K, n_per_row); + if (!quant_weights) { + wsp_quantize_row_q4_K_reference(src, dst, nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + // ====================== 5-bit (de)-quantization void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { @@ -2065,7 +2568,7 @@ void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * rest #else float max_scale = 0, amax = 0; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL); float abs_scale = fabsf(scales[j]); if (abs_scale > amax) { amax = abs_scale; @@ -2159,21 +2662,138 @@ void wsp_dewsp_quantize_row_q5_K(const block_q5_K * restrict x, float * restrict } } -void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { - assert(k % QK_K == 0); - block_q5_K * restrict y = vy; - wsp_quantize_row_q5_K_reference(x, y, k); -} +void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + wsp_quantize_row_q5_K_reference(x, y, k); +} + +size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; + wsp_quantize_row_q5_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q5_K)); +} + +static void wsp_quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + 
wsp_quantize_row_q5_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float weights[32]; + uint8_t Laux[32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; -size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms + } +#endif +} - for (int j = 0; j < n; j += k) { - block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; - wsp_quantize_row_q5_K_reference(src + j, y, k); +size_t wsp_quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_K, n_per_row); + if (!quant_weights) { + wsp_quantize_row_q5_K_reference(src, dst, nrow*n_per_row); } - return (n/QK_K*sizeof(block_q5_K)); + else { + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; } // ====================== 6-bit (de)-quantization @@ -2192,7 +2812,7 @@ void wsp_quantize_row_q6_K_reference(const float * restrict x, block_q6_K * rest for (int ib = 0; ib < 
QK_K/16; ++ib) { - const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); scales[ib] = scale; const float abs_scale = fabsf(scale); @@ -2324,6 +2944,569 @@ size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, i return (n/QK_K*sizeof(block_q6_K)); } +static void wsp_quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + wsp_quantize_row_q6_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + //float weights[16]; + + for (int i = 0; i < nb; i++) { + + //float sum_x2 = 0; + //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j]; + //float sigma2 = sum_x2/QK_K; + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + float scale; + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 16*ib; + //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]); + //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights); + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw); + } else { + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + } + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = WSP_GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = WSP_GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + + x += QK_K; + + } +#endif +} + +size_t wsp_quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q6_K, n_per_row); + if (!quant_weights) { + wsp_quantize_row_q6_K_reference(src, dst, nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +static void wsp_quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) { + static_assert(QK4_0 == 32, "QK4_0 must be 32"); + + if (!quant_weights) { + wsp_quantize_row_q4_0_reference(x, y, n_per_row); + return; + } + + float weight[QK4_0]; + int8_t L[QK4_0]; + + float 
sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int nb = n_per_row/QK4_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_0 * ib; + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); + y[ib].d = WSP_GGML_FP32_TO_FP16(d); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} + +size_t wsp_quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + if (!quant_weights) { + return wsp_ggml_wsp_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist); + } + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void wsp_quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) { + static_assert(QK4_1 == 32, "QK4_1 must be 32"); + + if (!quant_weights) { + wsp_quantize_row_q4_1_reference(x, y, n_per_row); + return; + } + + float weight[QK4_1]; + uint8_t L[QK4_1], Laux[QK4_1]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int nb = n_per_row/QK4_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_1 * ib; + const float * qw = quant_weights + QK4_1 * ib; + for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = WSP_GGML_FP32_TO_FP16(d); + y[ib].m = WSP_GGML_FP32_TO_FP16(-min); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} + +size_t wsp_quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + if (!quant_weights) { + return wsp_ggml_wsp_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist); + } + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_1, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void wsp_quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) { + static_assert(QK5_0 == 32, "QK5_0 must be 32"); + + if (!quant_weights) { + wsp_quantize_row_q5_0_reference(x, y, n_per_row); + return; + } + + float weight[QK5_0]; + int8_t L[QK5_0]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int nb = n_per_row/QK5_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_0 * ib; + const float * qw = quant_weights + QK5_0 * ib; + for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight); + y[ib].d = WSP_GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right 
position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + memcpy(&y[ib].qh, &qh, sizeof(qh)); + } +} + +size_t wsp_quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + if (!quant_weights) { + return wsp_ggml_wsp_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist); + } + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_0, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +static void wsp_quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) { + static_assert(QK5_1 == 32, "QK5_1 must be 32"); + + if (!quant_weights) { + wsp_quantize_row_q5_1_reference(x, y, n_per_row); + return; + } + + float weight[QK5_1]; + uint8_t L[QK5_1], Laux[QK5_1]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int nb = n_per_row/QK5_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_1 * ib; + const float * qw = quant_weights + QK5_1 * ib; + for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = WSP_GGML_FP32_TO_FP16(d); + y[ib].m = WSP_GGML_FP32_TO_FP16(-min); + + uint32_t qh = 0; + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(&y[ib].qh, &qh, sizeof(qh)); + } +} + +size_t wsp_quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + if (!quant_weights) { + return wsp_ggml_wsp_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist); + } + size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_1, n_per_row); + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + +// ====================== "True" 2-bit (de)-quantization + +static const uint64_t iq2xxs_grid[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 
0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; + +static const uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 
0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 
0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; + +static const uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; + +static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + +void wsp_dewsp_quantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + for (int i = 0; i < nb; i++) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t)); + const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} + +// ====================== 2.3125 bpw (de)-quantization + +void wsp_dewsp_quantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + float db[2]; + + for (int i = 0; i < nb; i++) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511)); + const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9]; + for (int j = 0; j < 8; ++j) { + y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); + } + y += 8; + } + } + } +} + //===================================== Q8_K ============================================== void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -2346,7 +3529,9 @@ void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * rest x += QK_K; continue; } - const float iscale = -128.f/max; + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; for (int j = 0; j < QK_K; ++j) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); @@ -2468,32 +3653,12 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); @@ -2776,32 +3941,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * re const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), 
WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; @@ -2963,32 +4108,12 @@ void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * re const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); @@ -3275,32 +4400,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * re const int8x16_t v1_1l 
= vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; @@ -3550,34 +4655,13 @@ void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * re const int8x16_t y1_0 = vld1q_s8(y1->qs); const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); - -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), 
vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif + wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); @@ -3650,12 +4734,10 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); const uint8x16_t m4 = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + + const int32x4_t vzero = vdupq_n_s32(0); wsp_ggml_int8x16x2_t q2bytes; uint8_t aux[16]; @@ -3663,7 +4745,6 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re float sum = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); @@ -3677,7 +4758,7 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums); - const wsp_ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const wsp_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), @@ -3689,20 +4770,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re // We use this macro instead of a function call because for some reason // the code runs 2-3% slower, even if the function is declared inline -#if defined(__ARM_FEATURE_DOTPROD) -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; -#else #define MULTIPLY_ACCUM_WITH_SCALE(index)\ - {\ - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ - isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ - } -#endif + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;\ @@ -3710,26 +4780,23 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ 
MULTIPLY_ACCUM_WITH_SCALE((index)); - for (int j = 0; j < QK_K/128; ++j) { - const wsp_ggml_uint8x16x2_t q2bits = wsp_ggml_vld1q_u8_x2(q2); q2 += 32; wsp_ggml_int8x16x2_t q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); is += 8; } - sum += d * isum; + sum += d * isum; } *s = sum; @@ -4043,11 +5110,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + + const int32x4_t vzero = vdupq_n_s32(0); wsp_ggml_int8x16x4_t q2bytes; @@ -4081,28 +5146,12 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); -#if defined(__ARM_FEATURE_DOTPROD) - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum1 += vaddvq_s16(p1) * scales[0]; - isum2 += vaddvq_s16(p2) * scales[1]; - - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum1 += vaddvq_s16(p3) * scales[2]; - isum2 += vaddvq_s16(p4) * scales[3]; -#endif - sum += d * (isum1 + isum2); + isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; + sum += d * (isum1 + isum2); } *s = sum; @@ -4328,9 +5377,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re uint32_t utmp[4]; const uint8x16_t m3b = vdupq_n_u8(0x3); -#ifdef __ARM_FEATURE_DOTPROD const int32x4_t vzero = vdupq_n_s32(0); -#endif const uint8x16_t m0 = vdupq_n_u8(1); const uint8x16_t m1 = vshlq_n_u8(m0, 1); @@ -4382,22 +5429,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); -#if 
defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; -#else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; + scale += 4; q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); @@ -4410,22 +5446,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; + scale += 4; if 
(j == 0) { @@ -4864,10 +5889,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const int32x4_t vzero = vdupq_n_s32(0); const uint8x16_t m3b = vdupq_n_u8(0x3); const uint8x16_t mh = vdupq_n_u8(4); @@ -4908,22 +5930,10 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; -#endif + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; sum += d * isum; @@ -5228,11 +6238,8 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re uint32_t utmp[4]; #ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); -#ifdef __ARM_FEATURE_DOTPROD const int32x4_t mzero = vdupq_n_s32(0); -#endif wsp_ggml_int8x16x2_t q4bytes; wsp_ggml_int8x16x2_t q8bytes; @@ -5269,44 +6276,22 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { - const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32; -#ifdef __ARM_FEATURE_DOTPROD q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - - q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; -#else - q8bytes = 
wsp_ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; -#endif + const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); @@ -5603,12 +6588,9 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); -#ifdef __ARM_FEATURE_DOTPROD const int32x4_t mzero = vdupq_n_s32(0); -#endif float sumf = 0; @@ -5636,41 +6618,20 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); -#ifdef __ARM_FEATURE_DOTPROD q8bytes = wsp_ggml_vld1q_s8_x4(q8); q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; -#else - q8bytes = wsp_ggml_vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * 
scales[0]; - - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); - int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; - -#endif sumf += d * (sumi1 + sumi2); - } *s = sumf - sum_mins; @@ -5875,15 +6836,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re uint32_t utmp[4]; - #ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t mone = vdupq_n_u8(1); const uint8x16_t mtwo = vdupq_n_u8(2); -#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t mzero = vdupq_n_s32(0); -#endif wsp_ggml_int8x16x4_t q5bytes; @@ -5938,28 +6895,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) - - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; -#else - - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; -#endif + sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; } sumf += d * sumi - dmin * sumi_mins; - } *s = sumf; @@ -6311,12 +7251,9 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t mh = vdupq_n_u8(16); -#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t mzero = vdupq_n_s32(0); -#endif wsp_ggml_int8x16x4_t q5bytes; wsp_ggml_uint8x16x4_t q5h; @@ -6348,32 +7285,12 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) - - int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, 
q5bytes.val[0], q8bytes.val[0])); - int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); - int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); - int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + int32_t sumi1 = sc[0] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); - -#else - - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); - - sumf += d*sumi; -#endif - } *s = sumf; @@ -6600,13 +7517,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t vzero = vdupq_n_s32(0); -#endif //const int8x16_t m32s = vdupq_n_s8(32); const uint8x16_t mone = vdupq_n_u8(3); @@ -6626,7 +7540,7 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums); const int8x16_t scales = vld1q_s8(scale); - const wsp_ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + const wsp_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), @@ -6658,31 +7572,13 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], 
q8bytes.val[3])) * scale[3]; scale += 4; -#else - - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; -#endif - q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; shifted = vshrq_n_u8(qhbits.val[0], 4); @@ -6703,34 +7599,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; scale += 4; - - //for (int l = 0; l < 4; ++l) { - // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); - // isum += vaddvq_s32(p) * *scale++; - //} -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; -#endif - } //sum += isum * d_all * y[i].d; sum += d_all * y[i].d * (isum - 32 * isum_mins); @@ -7076,14 +7949,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re const int nb = n / QK_K; #ifdef __ARM_NEON - float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); const int8x16_t m32s = vdupq_n_s8(32); -#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t vzero = vdupq_n_s32(0); -#endif const uint8x16_t mone = vdupq_n_u8(3); @@ -7119,26 +7989,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re q6bytes.val[2] = 
vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#else - - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif + isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; sum += isum * d_all * y[i].d; @@ -7380,3 +8234,958 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re } #endif + +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 
1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; + +void wsp_ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + wsp_ggml_int8x16x4_t q2u; + wsp_ggml_int8x16x4_t q2s; + wsp_ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + 
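// each 7-bit field in bits 0..27 of aux32[1] selects one of 128 precomputed 8-byte sign patterns (keven_signs_q2xs); bits 28..31 carry the 4-bit block scale applied below. +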
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; 
l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void wsp_ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + wsp_ggml_int8x16x4_t q2u; + wsp_ggml_int8x16x4_t q2s; + wsp_ggml_int8x16x4_t q8b; + + int32x4x4_t scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, 
scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m127 = _mm_set1_epi16(127); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m128i aux_gindex, aux_sindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + const uint16_t * sindex = (const uint16_t *)&aux_sindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8; + aux_gindex = _mm_and_si128(q2_data, m511); + aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127); + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? 
-1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +// ================================ IQ2 quantization ============================================= + +typedef struct { + uint64_t * grid; + int * map; + uint16_t * neighbours; +} iq2_entry_t; + +static iq2_entry_t iq2_data[2] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq2_data_index(int grid_size) { + WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512); + return grid_size == 256 ? 0 : 1; +} + +static int iq2_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; +} + +void iq2xs_init_impl(int grid_size) { + const int gindex = iq2_data_index(grid_size); + if (iq2_data[gindex].grid) { + return; + } + static const uint16_t kgrid_256[256] = { + 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97, + 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642, + 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288, + 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113, + 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240, + 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400, + 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260, + 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872, + 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516, + 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561, + 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488, + 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545, + 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874, + 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856, + 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142, + 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268, + }; + static const uint16_t kgrid_512[512] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257, + 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340, + 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597, + 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096, + 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348, + 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065, + 
2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441, + 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160, + 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372, + 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125, + 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652, + 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197, + 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549, + 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894, + 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388, + 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480, + 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773, + 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473, + 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436, + 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497, + 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162, + 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528, + 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745, + 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234, + 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025, + 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810, + 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984, + 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462, + 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960, + 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, + 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, + }; + const int kmap_size = 43692; + const int nwant = 2; + const uint16_t * kgrid = grid_size == 256 ? 
kgrid_256 : kgrid_512; + uint64_t * kgrid_q2xs; + int * kmap_q2xs; + uint16_t * kneighbors_q2xs; + + printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 8; ++i) { + int l = (kgrid[k] >> 2*i) & 0x3; + pos[i] = 2*l + 1; + } + } + kgrid_q2xs = the_grid; + iq2_data[gindex].grid = the_grid; + kmap_q2xs = (int *)malloc(kmap_size*sizeof(int)); + iq2_data[gindex].map = kmap_q2xs; + for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1; + uint64_t aux64; + uint8_t * aux8 = (uint8_t *)&aux64; + for (int i = 0; i < grid_size; ++i) { + aux64 = kgrid_q2xs[i]; + uint16_t index = 0; + for (int k=0; k<8; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 2*k); + } + kmap_q2xs[index] = i; + } + int8_t pos[8]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq2_data[gindex].neighbours = kneighbors_q2xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + kmap_q2xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} + +void iq2xs_free_impl(int grid_size) { + WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024); + const int gindex = iq2_data_index(grid_size); + if (iq2_data[gindex].grid) { + free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; + free(iq2_data[gindex].map); iq2_data[gindex].map = NULL; + free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL; + } +} + +static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { + int num_neighbors = neighbours[0]; + WSP_GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const 
int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + WSP_GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static void wsp_quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { + + const int gindex = iq2_data_index(256); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + WSP_GGML_ASSERT(quant_weights && "missing quantization weights"); + WSP_GGML_ASSERT(kgrid_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(kmap_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int nbl = n/256; + + block_iq2_xxs * y = vy; + + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + bool is_on_grid[4]; + bool is_on_grid_aux[4]; + uint8_t block_signs[4]; + uint32_t q2[2*(QK_K/32)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = 
sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 32; ++i) L[i] = Laux[i]; + for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 4; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index); + for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + WSP_GGML_ASSERT(false); + } + q2[2*ib+0] |= (grid_index << 8*k); + q2[2*ib+1] |= (block_signs[k] << 7*k); + } + WSP_GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } + + float d = max_scale/31; + y[ibl].d = WSP_GGML_FP32_TO_FP16(d); + float id = 1/d; + float sumqx = 0, sumq2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + q2[2*ib+1] |= ((uint32_t)l << 28); + const float * xb = xbl + 32*ib; + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib); + const float db = d * (1 + 2*l); + uint32_t u = 0; + for (int k = 0; k < 4; ++k) { + const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127); + const float * xk = xb + 8*k; + const float * wk = weight + 8*k; + const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]); + float best_mse = 0; int best_index = aux8[k]; + for (int j = 0; j < 8; ++j) { + float diff = db * grid[j] * signs[j] - xk[j]; + best_mse += wk[j] * diff * diff; + } + for (int idx = 0; idx < 256; ++idx) { + grid = (const uint8_t *)(kgrid_q2xs + idx); + float mse = 0; + for (int j = 0; j < 8; ++j) { + float diff = db * grid[j] * signs[j] - xk[j]; + mse += wk[j] * diff * diff; + } + if (mse < best_mse) { + best_mse = mse; best_index = idx; + } + } + u |= (best_index << 8*k); + grid = (const uint8_t *)(kgrid_q2xs + best_index); + //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]); + for (int j = 0; j < 8; ++j) { + float q = db * grid[j] * signs[j]; + sumqx += wk[j] * q * xk[j]; + sumq2 += wk[j] * q * q; + } + } + q2[2*ib] = u; + if (sumq2 > 0) y[ibl].d = WSP_GGML_FP32_TO_FP16(d*sumqx/sumq2); + } + memcpy(y[ibl].qs, q2, QK_K/4); + } +} + +static void 
wsp_quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { + + const int gindex = iq2_data_index(512); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + WSP_GGML_ASSERT(quant_weights && "missing quantization weights"); + WSP_GGML_ASSERT(kmap_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(kgrid_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?"); + WSP_GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int nbl = n/256; + + block_iq2_xs * y = vy; + + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + uint16_t q2[2*(QK_K/16)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + memset(y[ibl].scales, 0, QK_K/32); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 16); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) 
continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + WSP_GGML_ASSERT(false); + } + q2[2*ib+k] = grid_index | (block_signs[k] << 9); + } + WSP_GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } + + float d = max_scale/31; + y[ibl].d = WSP_GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + memcpy(y[ibl].qs, q2, QK_K/4); + + } +} + +size_t wsp_quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + WSP_GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xxs); + } + return nrow * nblock * sizeof(block_iq2_xxs); +} + +size_t wsp_quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + WSP_GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + wsp_quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xs); + } + return nrow * nblock * sizeof(block_iq2_xs); +} + diff --git a/cpp/ggml-quants.h b/cpp/ggml-quants.h index b0e60cf..abdba2e 100644 --- a/cpp/ggml-quants.h +++ b/cpp/ggml-quants.h @@ -70,7 +70,7 @@ static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block s // 2-bit quantization // weight is represented as x = a * q + b // 16 blocks of 16 elements each -// Effectively 2.5625 bits per weight +// Effectively 2.625 bits per weight typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants @@ -165,6 +165,22 @@ typedef struct { } block_q8_K; static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml dsign, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. 
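Editorial aside, not part of the synced upstream sources: a minimal sketch that checks the bits-per-weight figures quoted here, the 2.0625 bpw noted in the comment above for block_iq2_xxs and the 2.3125 bpw for block_iq2_xs, both declared just below. It assumes QK_K == 256 (the super-block size used throughout this patch) and simply counts the bits in each struct.

// Sanity check of the IQ2 bits-per-weight figures (editorial sketch, assumes QK_K == 256).
#include <stdio.h>

int main(void) {
    const int QK_K = 256;                                      // weights per super-block
    // block_iq2_xxs: one fp16 scale (d) + (QK_K/8) uint16_t packed quants (qs)
    const int bits_iq2_xxs = 16 + (QK_K / 8) * 16;             // 528 bits per block
    // block_iq2_xs: same layout plus (QK_K/32) uint8_t per-sub-block scales
    const int bits_iq2_xs  = bits_iq2_xxs + (QK_K / 32) * 8;   // 592 bits per block
    printf("iq2_xxs: %.4f bpw\n", (double)bits_iq2_xxs / QK_K); // prints 2.0625
    printf("iq2_xs : %.4f bpw\n", (double)bits_iq2_xs  / QK_K); // prints 2.3125
    return 0;
}

These totals match the static_asserts that follow (66 and 74 bytes per 256-weight block, respectively).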
+typedef struct { + wsp_ggml_fp16_t d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(wsp_ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +// 2.3125 bpw quants +typedef struct { + wsp_ggml_fp16_t d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(wsp_ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); // Quantization void wsp_quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); @@ -209,6 +225,8 @@ void wsp_dewsp_quantize_row_q4_K(const block_q4_K * restrict x, float * restrict void wsp_dewsp_quantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); void wsp_dewsp_quantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); void wsp_dewsp_quantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); +void wsp_dewsp_quantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k); +void wsp_dewsp_quantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k); // Dot product void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -222,3 +240,23 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +// +// Quantization utilizing an importance matrix (a.k.a. 
"Activation aWare Quantization") +// +size_t wsp_quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q5_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q6_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q4_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t wsp_quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); + +void iq2xs_init_impl(int grid_size); +void iq2xs_free_impl(int grid_size); diff --git a/cpp/ggml.c b/cpp/ggml.c index 014248e..93e9928 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -132,7 +132,7 @@ void wsp_ggml_print_backtrace(void) { "-ex", "bt -frame-info source-and-location", "-ex", "detach", "-ex", "quit", - NULL); + (char *) NULL); } else { waitpid(pid, NULL, 0); } @@ -573,6 +573,28 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { .vec_dot = wsp_ggml_vec_dot_q6_K_q8_K, .vec_dot_type = WSP_GGML_TYPE_Q8_K, }, + [WSP_GGML_TYPE_IQ2_XXS] = { + .type_name = "iq2_xxs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xxs), + .is_quantized = true, + .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_iq2_xxs, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = wsp_ggml_vec_dot_iq2_xxs_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, + [WSP_GGML_TYPE_IQ2_XS] = { + .type_name = "iq2_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xs), + .is_quantized = true, + .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_iq2_xs, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = wsp_ggml_vec_dot_iq2_xs_q8_K, + .vec_dot_type = WSP_GGML_TYPE_Q8_K, + }, [WSP_GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -1962,19 +1984,19 @@ void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx) { WSP_GGML_PRINT("%s: --- end ---\n", __func__); } -int64_t wsp_ggml_nelements(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL int64_t wsp_ggml_nelements(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -int64_t wsp_ggml_nrows(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL int64_t wsp_ggml_nrows(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -size_t wsp_ggml_nbytes(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL size_t wsp_ggml_nbytes(const 
struct wsp_ggml_tensor * tensor) { size_t nbytes; size_t blck_size = wsp_ggml_blck_size(tensor->type); if (blck_size == 1) { @@ -1997,33 +2019,32 @@ size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor) { return WSP_GGML_PAD(wsp_ggml_nbytes(tensor), WSP_GGML_MEM_ALIGN); } -size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split) { - static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); - - return (nrows_split*tensor->ne[0]*wsp_ggml_type_size(tensor->type))/wsp_ggml_blck_size(tensor->type); -} - -int wsp_ggml_blck_size(enum wsp_ggml_type type) { +WSP_GGML_CALL int wsp_ggml_blck_size(enum wsp_ggml_type type) { return type_traits[type].blck_size; } -size_t wsp_ggml_type_size(enum wsp_ggml_type type) { +WSP_GGML_CALL size_t wsp_ggml_type_size(enum wsp_ggml_type type) { return type_traits[type].type_size; } -float wsp_ggml_type_sizef(enum wsp_ggml_type type) { - return ((float)(type_traits[type].type_size))/type_traits[type].blck_size; +WSP_GGML_CALL size_t wsp_ggml_row_size(enum wsp_ggml_type type, int64_t ne) { + assert(ne % wsp_ggml_blck_size(type) == 0); + return wsp_ggml_type_size(type)*ne/wsp_ggml_blck_size(type); +} + +double wsp_ggml_type_sizef(enum wsp_ggml_type type) { + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; } -const char * wsp_ggml_type_name(enum wsp_ggml_type type) { +WSP_GGML_CALL const char * wsp_ggml_type_name(enum wsp_ggml_type type) { return type_traits[type].type_name; } -bool wsp_ggml_is_quantized(enum wsp_ggml_type type) { +WSP_GGML_CALL bool wsp_ggml_is_quantized(enum wsp_ggml_type type) { return type_traits[type].is_quantized; } -const char * wsp_ggml_op_name(enum wsp_ggml_op op) { +WSP_GGML_CALL const char * wsp_ggml_op_name(enum wsp_ggml_op op) { return WSP_GGML_OP_NAME[op]; } @@ -2035,7 +2056,7 @@ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) { return WSP_GGML_UNARY_OP_NAME[op]; } -const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) { +WSP_GGML_CALL const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) { if (t->op == WSP_GGML_OP_UNARY) { enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t); return wsp_ggml_unary_op_name(uop); @@ -2045,28 +2066,41 @@ const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) { } } -size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) { return wsp_ggml_type_size(tensor->type); } -static inline bool wsp_ggml_is_scalar(const struct wsp_ggml_tensor * tensor) { +bool wsp_ggml_is_scalar(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } -static inline bool wsp_ggml_is_vector(const struct wsp_ggml_tensor * tensor) { +bool wsp_ggml_is_vector(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } -static inline bool wsp_ggml_is_matrix(const struct wsp_ggml_tensor * tensor) { +bool wsp_ggml_is_matrix(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[2] == 1 && tensor->ne[3] == 1; } +bool wsp_ggml_is_3d(const struct wsp_ggml_tensor * tensor) { + return tensor->ne[3] == 1; 
+} + +int wsp_ggml_n_dims(const struct wsp_ggml_tensor * tensor) { + for (int i = WSP_GGML_MAX_DIMS - 1; i >= 1; --i) { + if (tensor->ne[i] > 1) { + return i + 1; + } + } + return 1; +} + static inline bool wsp_ggml_can_mul_mat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); @@ -2099,6 +2133,8 @@ enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) { case WSP_GGML_FTYPE_MOSTLY_Q4_K: wtype = WSP_GGML_TYPE_Q4_K; break; case WSP_GGML_FTYPE_MOSTLY_Q5_K: wtype = WSP_GGML_TYPE_Q5_K; break; case WSP_GGML_FTYPE_MOSTLY_Q6_K: wtype = WSP_GGML_TYPE_Q6_K; break; + case WSP_GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = WSP_GGML_TYPE_IQ2_XXS; break; + case WSP_GGML_FTYPE_MOSTLY_IQ2_XS: wtype = WSP_GGML_TYPE_IQ2_XS; break; case WSP_GGML_FTYPE_UNKNOWN: wtype = WSP_GGML_TYPE_COUNT; break; case WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = WSP_GGML_TYPE_COUNT; break; } @@ -2112,11 +2148,11 @@ size_t wsp_ggml_tensor_overhead(void) { return WSP_GGML_OBJECT_SIZE + WSP_GGML_TENSOR_SIZE; } -bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } -bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return @@ -2135,7 +2171,7 @@ static inline bool wsp_ggml_is_contiguous_except_dim_1(const struct wsp_ggml_ten tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -bool wsp_ggml_is_permuted(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL bool wsp_ggml_is_permuted(const struct wsp_ggml_tensor * tensor) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; @@ -2312,6 +2348,10 @@ struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) { } void wsp_ggml_free(struct wsp_ggml_context * ctx) { + if (ctx == NULL) { + return; + } + // make this function thread safe wsp_ggml_critical_section_start(); @@ -2371,20 +2411,8 @@ size_t wsp_ggml_get_mem_size(const struct wsp_ggml_context * ctx) { size_t wsp_ggml_get_max_tensor_size(const struct wsp_ggml_context * ctx) { size_t max_size = 0; - struct wsp_ggml_object * obj = ctx->objects_begin; - - while (obj != NULL) { - if (obj->type == WSP_GGML_OBJECT_TENSOR) { - struct wsp_ggml_tensor * tensor = (struct wsp_ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); - - const size_t size = wsp_ggml_nbytes(tensor); - - if (max_size < size) { - max_size = size; - } - } - - obj = obj->next; + for (struct wsp_ggml_tensor * tensor = wsp_ggml_get_first_tensor(ctx); tensor != NULL; tensor = wsp_ggml_get_next_tensor(ctx, tensor)) { + max_size = MAX(max_size, wsp_ggml_nbytes(tensor)); } return max_size; @@ -2473,7 +2501,7 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl( view_src = view_src->view_src; } - size_t data_size = wsp_ggml_type_size(type)*(ne[0]/wsp_ggml_blck_size(type)); + size_t data_size = wsp_ggml_row_size(type, ne[0]); for (int i = 1; i < n_dims; i++) { data_size *= ne[i]; } @@ -2516,7 +2544,6 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl( /*.type =*/ type, /*.backend =*/ WSP_GGML_BACKEND_CPU, /*.buffer =*/ NULL, - /*.n_dims =*/ n_dims, /*.ne 
=*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ WSP_GGML_OP_NONE, @@ -2623,7 +2650,7 @@ struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float v } struct wsp_ggml_tensor * wsp_ggml_dup_tensor(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src) { - return wsp_ggml_new_tensor(ctx, src->type, src->n_dims, src->ne); + return wsp_ggml_new_tensor(ctx, src->type, WSP_GGML_MAX_DIMS, src->ne); } static void wsp_ggml_set_op_params(struct wsp_ggml_tensor * tensor, const void * params, size_t params_size) { @@ -3046,7 +3073,7 @@ float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor) { return (float *)(tensor->data); } -enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor) { +WSP_GGML_CALL enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor) { WSP_GGML_ASSERT(tensor->op == WSP_GGML_OP_UNARY); return (enum wsp_ggml_unary_op) wsp_ggml_get_op_params_i32(tensor, 0); } @@ -3072,7 +3099,7 @@ struct wsp_ggml_tensor * wsp_ggml_format_name(struct wsp_ggml_tensor * tensor, c struct wsp_ggml_tensor * wsp_ggml_view_tensor( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src) { - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, src->type, WSP_GGML_MAX_DIMS, src->ne, src, 0); wsp_ggml_format_name(result, "%s (view)", src->name); for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) { @@ -3082,7 +3109,7 @@ struct wsp_ggml_tensor * wsp_ggml_view_tensor( return result; } -struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx) { +struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx) { struct wsp_ggml_object * obj = ctx->objects_begin; char * const mem_buffer = ctx->mem_buffer; @@ -3098,7 +3125,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx return NULL; } -struct wsp_ggml_tensor * wsp_ggml_get_next_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) { +struct wsp_ggml_tensor * wsp_ggml_get_next_tensor(const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) { struct wsp_ggml_object * obj = (struct wsp_ggml_object *) ((char *)tensor - WSP_GGML_OBJECT_SIZE); obj = obj->next; @@ -3230,10 +3257,10 @@ static struct wsp_ggml_tensor * wsp_ggml_add_cast_impl( is_node = true; } - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, a->n_dims, a->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, WSP_GGML_MAX_DIMS, a->ne); result->op = WSP_GGML_OP_ADD; - result->grad = is_node ? wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, a->ne) : NULL; + result->grad = is_node ? wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, a->ne) : NULL; result->src[0] = a; result->src[1] = b; @@ -3602,12 +3629,12 @@ struct wsp_ggml_tensor * wsp_ggml_sum_rows( is_node = true; } - int64_t ne[4] = {1,1,1,1}; - for (int i=1; in_dims; ++i) { + int64_t ne[WSP_GGML_MAX_DIMS] = { 1 }; + for (int i = 1; i < WSP_GGML_MAX_DIMS; ++i) { ne[i] = a->ne[i]; } - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, a->n_dims, ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, ne); result->op = WSP_GGML_OP_SUM_ROWS; result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -3628,8 +3655,8 @@ struct wsp_ggml_tensor * wsp_ggml_mean( is_node = true; } - int64_t ne[WSP_GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, ne); + int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); result->op = WSP_GGML_OP_MEAN; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -3651,8 +3678,7 @@ struct wsp_ggml_tensor * wsp_ggml_argmax( is_node = true; } - int64_t ne[WSP_GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, a->ne[1]); result->op = WSP_GGML_OP_ARGMAX; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -3675,7 +3701,7 @@ struct wsp_ggml_tensor * wsp_ggml_repeat( is_node = true; } - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne); result->op = WSP_GGML_OP_REPEAT; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -3702,7 +3728,7 @@ struct wsp_ggml_tensor * wsp_ggml_repeat_back( return a; } - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne); result->op = WSP_GGML_OP_REPEAT_BACK; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -4043,7 +4069,6 @@ static struct wsp_ggml_tensor * wsp_ggml_group_norm_impl( result->op = WSP_GGML_OP_GROUP_NORM; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } @@ -4078,7 +4103,7 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat( } const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); result->op = WSP_GGML_OP_MUL_MAT; result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -4088,6 +4113,14 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat( return result; } +void wsp_ggml_mul_mat_set_prec( + struct wsp_ggml_tensor * a, + enum wsp_ggml_prec prec) { + const int32_t prec_i32 = (int32_t) prec; + + wsp_ggml_set_op_params_i32(a, 0, prec_i32); +} + // wsp_ggml_mul_mat_id struct wsp_ggml_tensor * wsp_ggml_mul_mat_id( @@ -4112,7 +4145,7 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id( } const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); wsp_ggml_set_op_params_i32(result, 0, id); wsp_ggml_set_op_params_i32(result, 1, n_as); @@ -4150,7 +4183,7 @@ struct wsp_ggml_tensor * wsp_ggml_out_prod( // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); result->op = WSP_GGML_OP_OUT_PROD; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -4165,23 +4198,23 @@ struct wsp_ggml_tensor * wsp_ggml_out_prod( static struct wsp_ggml_tensor * wsp_ggml_scale_impl( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b, + float s, bool inplace) { - WSP_GGML_ASSERT(wsp_ggml_is_scalar(b)); WSP_GGML_ASSERT(wsp_ggml_is_padded_1d(a)); bool is_node = false; - if (a->grad || b->grad) { + if (a->grad) { is_node = true; } struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + wsp_ggml_set_op_params(result, &s, sizeof(s)); + result->op = WSP_GGML_OP_SCALE; result->grad = is_node ? 
wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -4189,15 +4222,15 @@ static struct wsp_ggml_tensor * wsp_ggml_scale_impl( struct wsp_ggml_tensor * wsp_ggml_scale( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b) { - return wsp_ggml_scale_impl(ctx, a, b, false); + float s) { + return wsp_ggml_scale_impl(ctx, a, s, false); } struct wsp_ggml_tensor * wsp_ggml_scale_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b) { - return wsp_ggml_scale_impl(ctx, a, b, true); + float s) { + return wsp_ggml_scale_impl(ctx, a, s, true); } // wsp_ggml_set @@ -4294,13 +4327,13 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace( static struct wsp_ggml_tensor * wsp_ggml_cpy_impl( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b, - bool inplace) { + struct wsp_ggml_tensor * b) { WSP_GGML_ASSERT(wsp_ggml_nelements(a) == wsp_ggml_nelements(b)); bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { + // inplace is false and either one have a grad is_node = true; } @@ -4324,29 +4357,38 @@ struct wsp_ggml_tensor * wsp_ggml_cpy( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b) { - return wsp_ggml_cpy_impl(ctx, a, b, false); + return wsp_ggml_cpy_impl(ctx, a, b); } -struct wsp_ggml_tensor * wsp_ggml_cpy_inplace( +struct wsp_ggml_tensor * wsp_ggml_cast( struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b) { - return wsp_ggml_cpy_impl(ctx, a, b, true); + struct wsp_ggml_tensor * a, + enum wsp_ggml_type type) { + bool is_node = false; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, WSP_GGML_MAX_DIMS, a->ne); + wsp_ggml_format_name(result, "%s (copy)", a->name); + + result->op = WSP_GGML_OP_CPY; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = result; + + return result; } // wsp_ggml_cont static struct wsp_ggml_tensor * wsp_ggml_cont_impl( struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - bool inplace) { + struct wsp_ggml_tensor * a) { bool is_node = false; - if (!inplace && a->grad) { + if (a->grad) { is_node = true; } - struct wsp_ggml_tensor * result = inplace ? 
wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); + struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); wsp_ggml_format_name(result, "%s (cont)", a->name); result->op = WSP_GGML_OP_CONT; @@ -4359,13 +4401,7 @@ static struct wsp_ggml_tensor * wsp_ggml_cont_impl( struct wsp_ggml_tensor * wsp_ggml_cont( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a) { - return wsp_ggml_cont_impl(ctx, a, false); -} - -struct wsp_ggml_tensor * wsp_ggml_cont_inplace( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a) { - return wsp_ggml_cont_impl(ctx, a, true); + return wsp_ggml_cont_impl(ctx, a); } // make contiguous, with new shape @@ -4435,7 +4471,7 @@ struct wsp_ggml_tensor * wsp_ggml_reshape( //WSP_GGML_ASSERT(false); } - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_impl(ctx, a->type, WSP_GGML_MAX_DIMS, b->ne, a, 0); wsp_ggml_format_name(result, "%s (reshaped)", a->name); result->op = WSP_GGML_OP_RESHAPE; @@ -4761,8 +4797,11 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows( } // TODO: implement non F32 return - //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); + enum wsp_ggml_type type = WSP_GGML_TYPE_F32; + if (a->type == WSP_GGML_TYPE_I32) { + type = a->type; + } + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); result->op = WSP_GGML_OP_GET_ROWS; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -4813,7 +4852,7 @@ struct wsp_ggml_tensor * wsp_ggml_diag( } const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, a->type, 4, ne); result->op = WSP_GGML_OP_DIAG; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -5460,7 +5499,7 @@ struct wsp_ggml_tensor * wsp_ggml_pool_1d( is_node = true; } - const int64_t ne[3] = { + const int64_t ne[2] = { wsp_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), a->ne[1], }; @@ -5535,7 +5574,6 @@ static struct wsp_ggml_tensor * wsp_ggml_upscale_impl( result->op_params[0] = scale_factor; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5579,7 +5617,7 @@ struct wsp_ggml_tensor * wsp_ggml_argsort( enum wsp_ggml_sort_order order) { bool is_node = false; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, a->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, WSP_GGML_MAX_DIMS, a->ne); wsp_ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5626,7 +5664,7 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn( } //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, q); - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, q->n_dims, q->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, q->ne); int32_t t = masked ? 
1 : 0; wsp_ggml_set_op_params(result, &t, sizeof(t)); @@ -5659,7 +5697,7 @@ struct wsp_ggml_tensor * wsp_ggml_flash_ff( } //struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, a->ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, WSP_GGML_MAX_DIMS, a->ne); result->op = WSP_GGML_OP_FLASH_FF; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -5775,7 +5813,6 @@ struct wsp_ggml_tensor * wsp_ggml_win_part( const int np = npx*npy; const int64_t ne[4] = { a->ne[0], w, w, np, }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); int32_t params[] = { npx, npy, w }; @@ -5841,7 +5878,6 @@ struct wsp_ggml_tensor * wsp_ggml_get_rel_pos( result->op = WSP_GGML_OP_GET_REL_POS; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6936,14 +6972,165 @@ static void wsp_ggml_compute_forward_dup_f32( } } -static void wsp_ggml_compute_forward_dup( +// A simplified version of wsp_ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. +static void wsp_ggml_compute_forward_dup_bytes( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, struct wsp_ggml_tensor * dst) { - if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst) && src0->type == dst->type) { + WSP_GGML_ASSERT(wsp_ggml_nelements(dst) == wsp_ggml_nelements(src0)); + WSP_GGML_ASSERT(src0->type == dst->type); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + if (wsp_ggml_is_contiguous(src0) && wsp_ggml_is_contiguous(dst)) { wsp_ggml_compute_forward_dup_same_cont(params, src0, dst); return; } + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + const size_t type_size = wsp_ggml_type_size(src0->type); + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == type_size && nb0 == type_size) { + // copy by rows + const size_t rs = ne00 * type_size; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (wsp_ggml_is_contiguous(dst)) { + size_t id = 0; + char * dst_ptr = (char *) dst->data; + const size_t rs = ne00 * type_size; + + if (nb00 == type_size) { + // src0 is contigous on first dimension, copy by rows + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = (char *) 
src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, type_size); + + id += type_size; + } + } + id += rs * (ne01 - ir1); + } + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, type_size); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + +static void wsp_ggml_compute_forward_dup( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + if (src0->type == dst->type) { + wsp_ggml_compute_forward_dup_bytes(params, src0, dst); + return; + } + switch (src0->type) { case WSP_GGML_TYPE_F16: { @@ -7280,6 +7467,8 @@ static void wsp_ggml_compute_forward_add( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: { wsp_ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7544,6 +7733,8 @@ static void wsp_ggml_compute_forward_add1( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: { wsp_ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -7658,6 +7849,8 @@ static void wsp_ggml_compute_forward_acc( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: default: { WSP_GGML_ASSERT(false); @@ -7759,10 +7952,10 @@ static void wsp_ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; -// TODO: OpenCL kernel support broadcast #ifdef WSP_GGML_USE_CLBLAST if (src1->backend == WSP_GGML_BACKEND_GPU) { - WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1)); + // TODO: OpenCL kernel support full broadcast + WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0)); if (ith == 0) { wsp_ggml_cl_mul(src0, src1, dst); } @@ -8402,10 +8595,12 @@ static void wsp_ggml_compute_forward_repeat( struct wsp_ggml_tensor * dst) { switch (src0->type) { case WSP_GGML_TYPE_F16: + case WSP_GGML_TYPE_I16: { wsp_ggml_compute_forward_repeat_f16(params, src0, dst); } break; case WSP_GGML_TYPE_F32: + case WSP_GGML_TYPE_I32: { wsp_ggml_compute_forward_repeat_f32(params, src0, dst); } break; @@ -8548,6 +8743,7 @@ static void wsp_ggml_compute_forward_concat( struct wsp_ggml_tensor* dst) { switch (src0->type) { case WSP_GGML_TYPE_F32: + case WSP_GGML_TYPE_I32: { wsp_ggml_compute_forward_concat_f32(params, src0, src1, dst); } break; @@ -9159,6 +9355,8 @@ static void wsp_ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + WSP_GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 
0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9228,6 +9426,8 @@ static void wsp_ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + WSP_GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9541,10 +9741,10 @@ static void wsp_ggml_compute_forward_group_norm( #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) // helper function to determine if it is better to use BLAS or not // for large matrices, BLAS is faster -static bool wsp_ggml_compute_forward_mul_mat_use_blas( - const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst) { +static bool wsp_ggml_compute_forward_mul_mat_use_blas(struct wsp_ggml_tensor * dst) { + const struct wsp_ggml_tensor * src0 = dst->src[0]; + const struct wsp_ggml_tensor * src1 = dst->src[1]; + //const int64_t ne00 = src0->ne[0]; //const int64_t ne01 = src0->ne[1]; @@ -9571,16 +9771,11 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas( } #endif -// off1 = offset in i11 and i1 -// cne1 = ne11 and ne1 -// in a normal matrix multiplication, off1 = 0 and cne1 = ne1 -// during WSP_GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1 static void wsp_ggml_compute_forward_mul_mat( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, - struct wsp_ggml_tensor * dst, - int64_t off1, int64_t cne1) { + struct wsp_ggml_tensor * dst) { int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); @@ -9629,7 +9824,7 @@ static void wsp_ggml_compute_forward_mul_mat( #endif #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) - if (wsp_ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if (wsp_ggml_compute_forward_mul_mat_use_blas(dst)) { if (params->ith != 0) { return; } @@ -9648,9 +9843,9 @@ static void wsp_ggml_compute_forward_mul_mat( const int64_t i03 = i13/r3; const int64_t i02 = i12/r2; - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13); - float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3); + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); if (type != WSP_GGML_TYPE_F32) { float * const wdata = params->wdata; @@ -9667,7 +9862,7 @@ static void wsp_ggml_compute_forward_mul_mat( } cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - cne1, ne01, ne10, + ne1, ne01, ne10, 1.0f, y, ne10, x, ne00, 0.0f, d, ne01); @@ -9683,10 +9878,10 @@ static void wsp_ggml_compute_forward_mul_mat( if (params->type == WSP_GGML_TASK_INIT) { if (src1->type != vec_dot_type) { char * wdata = params->wdata; - const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type); + const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10); assert(params->wsize >= ne11*ne12*ne13*row_size); - assert(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -9706,10 +9901,10 @@ static void wsp_ggml_compute_forward_mul_mat( } const void * wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata; - const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type); + const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1*ne12*ne13; // src1 rows + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = ne1*ne12*ne13; // src1 rows //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); @@ -9751,9 +9946,9 @@ static void wsp_ggml_compute_forward_mul_mat( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*cne1)); - const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; - const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1; + const int64_t i13 = (ir1/(ne12*ne1)); + const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; + const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); // broadcast src0 into src1 const int64_t i03 = i13/r3; @@ -9793,28 +9988,191 @@ static void wsp_ggml_compute_forward_mul_mat( static void wsp_ggml_compute_forward_mul_mat_id( const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * ids, const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { - if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { - // during WSP_GGML_TASK_INIT the entire src1 is converted to vec_dot_type - wsp_ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]); - return; - } + const struct wsp_ggml_tensor * src0 = dst->src[2]; // only for WSP_GGML_TENSOR_BINARY_OP_LOCALS + + WSP_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum wsp_ggml_type type = src0->type; + + const bool src1_cont = wsp_ggml_is_contiguous(src1); + + wsp_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum wsp_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + wsp_ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + + WSP_GGML_ASSERT(ne0 == ne01); + WSP_GGML_ASSERT(ne1 == ne11); + WSP_GGML_ASSERT(ne2 == ne12); + WSP_GGML_ASSERT(ne3 == ne13); - const struct wsp_ggml_tensor * ids = src0; + // we don't support permuted src0 or src1 + WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type)); + WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + WSP_GGML_ASSERT(nb0 <= nb1); + WSP_GGML_ASSERT(nb1 <= nb2); + WSP_GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + // row groups const int id = wsp_ggml_get_op_params_i32(dst, 0); const int n_as = wsp_ggml_get_op_params_i32(dst, 1); - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); + char * wdata_src1_end = (src1->type == vec_dot_type) ? 
+ (char *) params->wdata : + (char *) params->wdata + WSP_GGML_PAD(wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)), sizeof(int64_t)); + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + + #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + + if (params->type == WSP_GGML_TASK_INIT) { + char * wdata = params->wdata; + if (src1->type != vec_dot_type) { + const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10); + + assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == WSP_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + } + + // initialize matrix_row_counts + WSP_GGML_ASSERT(wdata == wdata_src1_end); + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); - WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as); + WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as); + MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + + return; + } - const struct wsp_ggml_tensor * src0_row = dst->src[row_id + 2]; - wsp_ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1); + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; } + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const struct wsp_ggml_tensor * src0_cur = dst->src[cur_a + 2]; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = wsp_ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1*ne12*ne13; // src1 rows + + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 
1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + if (ir010 >= ir011 || ir110 >= ir111) { + sched_yield(); + continue; + } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix + const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; + const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); + const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? 
(i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); + } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + + #undef MMID_MATRIX_ROW } // wsp_ggml_compute_forward_out_prod @@ -10134,6 +10492,8 @@ static void wsp_ggml_compute_forward_out_prod( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: { wsp_ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; @@ -10158,19 +10518,18 @@ static void wsp_ggml_compute_forward_out_prod( static void wsp_ggml_compute_forward_scale_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst)); - WSP_GGML_ASSERT(wsp_ggml_is_scalar(src1)); if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { return; } // scale factor - const float v = *(float *) src1->data; + float v; + memcpy(&v, dst->op_params, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -10201,12 +10560,11 @@ static void wsp_ggml_compute_forward_scale_f32( static void wsp_ggml_compute_forward_scale( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, - const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { switch (src0->type) { case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_scale_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_scale_f32(params, src0, dst); } break; default: { @@ -10310,6 +10668,8 @@ static void wsp_ggml_compute_forward_set( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: default: { WSP_GGML_ASSERT(false); @@ -10504,6 +10864,8 @@ static void wsp_ggml_compute_forward_get_rows( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: { wsp_ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -10512,6 +10874,7 @@ static void wsp_ggml_compute_forward_get_rows( wsp_ggml_compute_forward_get_rows_f16(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: + case WSP_GGML_TYPE_I32: { wsp_ggml_compute_forward_get_rows_f32(params, src0, src1, dst); } break; @@ -11139,6 +11502,8 @@ static void wsp_ggml_compute_forward_alibi( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: case WSP_GGML_TYPE_Q8_K: case WSP_GGML_TYPE_I8: case WSP_GGML_TYPE_I16: @@ -11213,6 +11578,8 @@ static void wsp_ggml_compute_forward_clamp( case WSP_GGML_TYPE_Q4_K: case WSP_GGML_TYPE_Q5_K: case WSP_GGML_TYPE_Q6_K: + case WSP_GGML_TYPE_IQ2_XXS: + case WSP_GGML_TYPE_IQ2_XS: case WSP_GGML_TYPE_Q8_K: case WSP_GGML_TYPE_I8: case WSP_GGML_TYPE_I16: @@ -11257,7 +11624,22 @@ static float wsp_ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot return n_dims * 
logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } -void wsp_ggml_rope_yarn_corr_dims( +static void wsp_ggml_rope_cache_init( + float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale +) { + float theta = theta_base; + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + rope_yarn( + theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta *= theta_scale; + } +} + +WSP_GGML_CALL void wsp_ggml_rope_yarn_corr_dims( int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims @@ -11339,6 +11721,12 @@ static void wsp_ggml_compute_forward_rope_f32( for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t p = pos[i2]; + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox + wsp_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -11372,18 +11760,13 @@ static void wsp_ggml_compute_forward_rope_f32( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - float cos_theta, sin_theta; - rope_yarn( - theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; // zeta scaling for xPos only: float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; - theta_base *= theta_scale; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -11395,10 +11778,13 @@ static void wsp_ggml_compute_forward_rope_f32( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11421,6 +11807,14 @@ static void wsp_ggml_compute_forward_rope_f32( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -11496,6 +11890,12 @@ static void wsp_ggml_compute_forward_rope_f16( for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t p = pos[i2]; + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (!is_glm && !is_neox) { 
// TODO: cache sin/cos for glm, neox + wsp_ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -11529,13 +11929,8 @@ static void wsp_ggml_compute_forward_rope_f16( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - float cos_theta, sin_theta; - rope_yarn( - theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; - - theta_base *= theta_scale; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -11548,10 +11943,13 @@ static void wsp_ggml_compute_forward_rope_f16( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11574,6 +11972,14 @@ static void wsp_ggml_compute_forward_rope_f16( dst_data[0] = WSP_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[n_dims/2] = WSP_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } else { + const int64_t i0 = ic; + + const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -14182,7 +14588,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st } break; case WSP_GGML_OP_MUL_MAT: { - wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]); + wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); } break; case WSP_GGML_OP_MUL_MAT_ID: { @@ -14194,7 +14600,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st } break; case WSP_GGML_OP_SCALE: { - wsp_ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor); + wsp_ggml_compute_forward_scale(params, tensor->src[0], tensor); } break; case WSP_GGML_OP_SET: { @@ -14489,7 +14895,7 @@ size_t wsp_ggml_hash_find_or_insert(struct wsp_ggml_hash_set hash_set, struct ws return i; } -static struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) { +struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) { size = wsp_ggml_hash_size(size); struct wsp_ggml_hash_set result; result.size = size; @@ -14558,7 +14964,7 @@ static struct wsp_ggml_tensor * wsp_ggml_recompute_graph_node( return replacements->vals[i]; } - struct wsp_ggml_tensor * clone = wsp_ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + struct wsp_ggml_tensor * clone = wsp_ggml_new_tensor(ctx, node->type, 
WSP_GGML_MAX_DIMS, node->ne); // insert clone into replacements WSP_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite @@ -14650,7 +15056,7 @@ static struct wsp_ggml_tensor * wsp_ggml_add_or_set(struct wsp_ggml_context * ct static struct wsp_ggml_tensor * wsp_ggml_acc_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct wsp_ggml_hash_set zero_table) { if (wsp_ggml_hash_contains(zero_table, a)) { - struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, a, wsp_ggml_new_f32(ctx, 0)); + struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, a, 0.0f); return wsp_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); } else { return wsp_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); @@ -14786,7 +15192,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ src0->grad, wsp_ggml_scale(ctx, wsp_ggml_mul(ctx, src0, tensor->grad), - wsp_ggml_new_f32(ctx, 2.0f)), + 2.0f), zero_table); } } break; @@ -14800,7 +15206,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ wsp_ggml_div(ctx, tensor->grad, tensor), - wsp_ggml_new_f32(ctx, 0.5f)), + 0.5f), zero_table); } } break; @@ -14966,17 +15372,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { // necessary for llama if (src0->grad) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, - wsp_ggml_scale_impl(ctx, tensor->grad, src1, false), - zero_table); - } - if (src1->grad) { - src1->grad = - wsp_ggml_add_or_set(ctx, - src1->grad, - wsp_ggml_sum(ctx, wsp_ggml_mul_impl(ctx, tensor->grad, src0, false)), + wsp_ggml_scale_impl(ctx, tensor->grad, s, false), zero_table); } } break; @@ -15154,6 +15556,8 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, + /* wsp_ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), zero_table); } @@ -15961,28 +16365,9 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) { //n_tasks = MIN(n_threads, MAX(1, nr0/128)); //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - -#if defined(WSP_GGML_USE_CUBLAS) - if (wsp_ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } -#elif defined(WSP_GGML_USE_CLBLAST) - if (wsp_ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } -#endif -#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) - if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } -#endif } break; case WSP_GGML_OP_MUL_MAT_ID: { - // FIXME: blas n_tasks = n_threads; } break; case WSP_GGML_OP_OUT_PROD: @@ -16152,6 +16537,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { state->shared->node_n += 1; return (thread_ret_t) WSP_GGML_EXIT_ABORTED; } + if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { // all other threads are finished and 
spinning // do finalize and init here so we don't have synchronize again @@ -16217,14 +16603,18 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { } else { // wait for other threads to finish const int last = node_n; + + const bool do_yield = last < 0 || cgraph->nodes[last]->op == WSP_GGML_OP_MUL_MAT; + while (true) { // TODO: this sched_yield can have significant impact on the performance - either positive or negative // depending on the workload and the operating system. // since it is not clear what is the best approach, it should potentially become user-configurable // ref: https://github.com/ggerganov/ggml/issues/291 -#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) - sched_yield(); -#endif + // UPD: adding the do_yield flag seems to resolve the issue universally + if (do_yield) { + sched_yield(); + } node_n = atomic_load(&state->shared->node_n); if (node_n != last) break; @@ -16254,7 +16644,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { return WSP_GGML_EXIT_SUCCESS; } -struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n_threads) { +struct wsp_ggml_cplan wsp_ggml_graph_plan(const struct wsp_ggml_cgraph * cgraph, int n_threads) { if (n_threads <= 0) { n_threads = WSP_GGML_DEFAULT_N_THREADS; } @@ -16303,7 +16693,7 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n } else #endif #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) - if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + if (wsp_ggml_compute_forward_mul_mat_use_blas(node)) { if (node->src[0]->type != WSP_GGML_TYPE_F32) { // here we need memory just for single 2D matrix from src0 cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); @@ -16311,25 +16701,22 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n } else #endif if (node->src[1]->type != vec_dot_type) { - cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type); + cur = wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(node->src[1])); } } break; case WSP_GGML_OP_MUL_MAT_ID: { - const struct wsp_ggml_tensor * a = node->src[2]; - const struct wsp_ggml_tensor * b = node->src[1]; - const enum wsp_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type; -#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) - if (wsp_ggml_compute_forward_mul_mat_use_blas(a, b, node)) { - if (a->type != WSP_GGML_TYPE_F32) { - // here we need memory just for single 2D matrix from src0 - cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]); - } - } else -#endif - if (b->type != vec_dot_type) { - cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(b)/wsp_ggml_blck_size(vec_dot_type); + cur = 0; + const struct wsp_ggml_tensor * src0 = node->src[2]; + const struct wsp_ggml_tensor * src1 = node->src[1]; + const enum wsp_ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur += wsp_ggml_row_size(vec_dot_type, wsp_ggml_nelements(src1)); } + const int n_as = wsp_ggml_get_op_params_i32(node, 1); + cur += WSP_GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows } break; case WSP_GGML_OP_OUT_PROD: { @@ -16338,6 +16725,7 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n } } break; case 
WSP_GGML_OP_SOFT_MAX: + case WSP_GGML_OP_ROPE: { cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks; } break; @@ -16559,7 +16947,7 @@ static void wsp_ggml_graph_export_leaf(const struct wsp_ggml_tensor * tensor, FI fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", wsp_ggml_type_name(tensor->type), wsp_ggml_op_name (tensor->op), - tensor->n_dims, + wsp_ggml_n_dims(tensor), ne[0], ne[1], ne[2], ne[3], nb[0], nb[1], nb[2], nb[3], tensor->data, @@ -16574,7 +16962,7 @@ static void wsp_ggml_graph_export_node(const struct wsp_ggml_tensor * tensor, co arg, wsp_ggml_type_name(tensor->type), wsp_ggml_op_name (tensor->op), - tensor->n_dims, + wsp_ggml_n_dims(tensor), ne[0], ne[1], ne[2], ne[3], nb[0], nb[1], nb[2], nb[3], tensor->data, @@ -16664,11 +17052,9 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f const uint32_t type = tensor->type; const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -16698,11 +17084,9 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f const uint32_t type = tensor->type; const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -16874,12 +17258,10 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg { uint32_t type; uint32_t op; - uint32_t n_dims; for (uint32_t i = 0; i < n_leafs; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); int64_t ne[WSP_GGML_MAX_DIMS]; size_t nb[WSP_GGML_MAX_DIMS]; @@ -16895,7 +17277,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg nb[j] = nb_cur; } - struct wsp_ggml_tensor * tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne); + struct wsp_ggml_tensor * tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, WSP_GGML_MAX_DIMS, ne); tensor->op = (enum wsp_ggml_op) op; @@ -16912,7 +17294,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg ptr += wsp_ggml_nbytes(tensor); - fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, wsp_ggml_nbytes(tensor)); } } @@ -16922,12 +17304,10 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg { uint32_t type; uint32_t op; - uint32_t n_dims; for (uint32_t i = 0; i < n_nodes; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); enum wsp_ggml_op eop = (enum wsp_ggml_op) op; @@ -16998,7 +17378,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg } break; default: { - tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, n_dims, ne); + tensor = wsp_ggml_new_tensor(*ctx_eval, (enum wsp_ggml_type) type, 
WSP_GGML_MAX_DIMS, ne); tensor->op = eop; } break; @@ -17017,7 +17397,7 @@ struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_gg result->nodes[i] = tensor; - fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, wsp_ggml_nbytes(tensor)); } } } @@ -17155,7 +17535,7 @@ void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp fprintf(fp, "(%s)|", wsp_ggml_type_name(node->type)); } - if (node->n_dims == 2) { + if (wsp_ggml_is_matrix(node)) { fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], wsp_ggml_op_symbol(node->op)); } else { fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], wsp_ggml_op_symbol(node->op)); @@ -17284,9 +17664,9 @@ static void wsp_ggml_opt_acc_grad(int np, struct wsp_ggml_tensor * const ps[], f } // -// ADAM +// Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf // -// ref: https://arxiv.org/pdf/1412.6980.pdf +// (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf) // static enum wsp_ggml_opt_result wsp_ggml_opt_adam( @@ -17422,7 +17802,7 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam( int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = wsp_ggml_nelements(ps[p]); - const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; + const float p_decay = ((wsp_ggml_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched; for (int64_t j = 0; j < ne; ++j) { float x = wsp_ggml_get_f32_1d(ps[p], j); float g_ = g[i]*gnorm; @@ -18144,6 +18524,28 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume_g( //////////////////////////////////////////////////////////////////////////////// +void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type) { + wsp_ggml_critical_section_start(); + + switch (type) { + case WSP_GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break; + case WSP_GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break; + default: // nothing + break; + } + + wsp_ggml_critical_section_end(); +} + +void wsp_ggml_wsp_quantize_free(void) { + wsp_ggml_critical_section_start(); + + iq2xs_free_impl(256); + iq2xs_free_impl(512); + + wsp_ggml_critical_section_end(); +} + size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK4_0 == 0); const int nb = k / QK4_0; @@ -18271,32 +18673,53 @@ size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, i return (n/QK8_0*sizeof(block_q8_0)); } -size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { +bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type) { + return + type == WSP_GGML_TYPE_IQ2_XXS || + type == WSP_GGML_TYPE_IQ2_XS; +} + +size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, + int nrows, int n_per_row, int64_t * hist, const float * imatrix) { + wsp_ggml_wsp_quantize_init(type); // this is noop if already initialized size_t result = 0; + int n = nrows * n_per_row; switch (type) { case WSP_GGML_TYPE_Q4_0: { WSP_GGML_ASSERT(start % QK4_0 == 0); - block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; - result = wsp_ggml_wsp_quantize_q4_0(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); 
+ result = wsp_quantize_q4_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q4_1: { WSP_GGML_ASSERT(start % QK4_1 == 0); - block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; - result = wsp_ggml_wsp_quantize_q4_1(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q4_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q5_0: { WSP_GGML_ASSERT(start % QK5_0 == 0); - block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; - result = wsp_ggml_wsp_quantize_q5_0(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q5_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q5_1: { WSP_GGML_ASSERT(start % QK5_1 == 0); - block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; - result = wsp_ggml_wsp_quantize_q5_1(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q5_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q8_0: { @@ -18307,42 +18730,77 @@ size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, v case WSP_GGML_TYPE_Q2_K: { WSP_GGML_ASSERT(start % QK_K == 0); - block_q2_K * block = (block_q2_K*)dst + start / QK_K; - result = wsp_ggml_wsp_quantize_q2_K(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q3_K: { WSP_GGML_ASSERT(start % QK_K == 0); - block_q3_K * block = (block_q3_K*)dst + start / QK_K; - result = wsp_ggml_wsp_quantize_q3_K(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q3_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q4_K: { WSP_GGML_ASSERT(start % QK_K == 0); - block_q4_K * block = (block_q4_K*)dst + start / QK_K; - result = wsp_ggml_wsp_quantize_q4_K(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q4_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q5_K: { WSP_GGML_ASSERT(start % QK_K == 0); - block_q5_K * block = (block_q5_K*)dst + start / QK_K; - result = wsp_ggml_wsp_quantize_q5_K(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / 
n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q5_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_Q6_K: { WSP_GGML_ASSERT(start % QK_K == 0); - block_q6_K * block = (block_q6_K*)dst + start / QK_K; - result = wsp_ggml_wsp_quantize_q6_K(src + start, block, n, n, hist); + WSP_GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_q6_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); + } break; + case WSP_GGML_TYPE_IQ2_XXS: + { + WSP_GGML_ASSERT(start % QK_K == 0); + WSP_GGML_ASSERT(start % n_per_row == 0); + WSP_GGML_ASSERT(imatrix); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); + } break; + case WSP_GGML_TYPE_IQ2_XS: + { + WSP_GGML_ASSERT(start % QK_K == 0); + WSP_GGML_ASSERT(start % n_per_row == 0); + WSP_GGML_ASSERT(imatrix); + size_t start_row = start / n_per_row; + size_t row_size = wsp_ggml_row_size(type, n_per_row); + result = wsp_quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + WSP_GGML_ASSERT(result == row_size * nrows); } break; case WSP_GGML_TYPE_F16: { - int elemsize = sizeof(wsp_ggml_fp16_t); + size_t elemsize = sizeof(wsp_ggml_fp16_t); wsp_ggml_fp32_to_fp16_row(src + start, (wsp_ggml_fp16_t *)dst + start, n); result = n * elemsize; } break; case WSP_GGML_TYPE_F32: { - int elemsize = sizeof(float); + size_t elemsize = sizeof(float); result = n * elemsize; memcpy((uint8_t *)dst + start * elemsize, src + start, result); } break; @@ -18689,14 +19147,14 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp (int64_t) info->ne[3]; if (ne % wsp_ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, ne, wsp_ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", + __func__, info->name.data, (int)info->type, wsp_ggml_type_name(info->type), ne, wsp_ggml_blck_size(info->type)); fclose(file); wsp_gguf_free(ctx); return NULL; } - const size_t size_cur = (ne*wsp_ggml_type_size(info->type))/wsp_ggml_blck_size(info->type); + const size_t size_cur = wsp_ggml_row_size(info->type, ne); ctx->size += WSP_GGML_PAD(size_cur, ctx->alignment); } @@ -18796,7 +19254,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) { if (ctx->kv) { // free string memory - not great.. 
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { struct wsp_gguf_kv * kv = &ctx->kv[i]; if (kv->key.data) { @@ -18812,7 +19270,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) { if (kv->type == WSP_GGUF_TYPE_ARRAY) { if (kv->value.arr.data) { if (kv->value.arr.type == WSP_GGUF_TYPE_STRING) { - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + for (uint64_t j = 0; j < kv->value.arr.n; ++j) { struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[j]; if (str->data) { free(str->data); @@ -18828,7 +19286,7 @@ void wsp_gguf_free(struct wsp_gguf_context * ctx) { } if (ctx->infos) { - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { struct wsp_gguf_tensor_info * info = &ctx->infos[i]; if (info->name.data) { @@ -19025,6 +19483,10 @@ char * wsp_gguf_get_tensor_name(const struct wsp_gguf_context * ctx, int i) { return ctx->infos[i].name.data; } +enum wsp_ggml_type wsp_gguf_get_tensor_type(const struct wsp_gguf_context * ctx, int i) { + return ctx->infos[i].type; +} + // returns the index static int wsp_gguf_get_or_add_key(struct wsp_gguf_context * ctx, const char * key) { const int idx = wsp_gguf_find_key(ctx, key); @@ -19175,7 +19637,7 @@ void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * sr data[j] = ((struct wsp_gguf_str *)src->kv[i].value.arr.data)[j].data; } wsp_gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); - free(data); + free((void *)data); } else if (src->kv[i].value.arr.type == WSP_GGUF_TYPE_ARRAY) { WSP_GGML_ASSERT(false && "nested arrays not supported"); } else { @@ -19200,8 +19662,8 @@ void wsp_gguf_add_tensor( ctx->infos[idx].ne[i] = 1; } - ctx->infos[idx].n_dims = tensor->n_dims; - for (int i = 0; i < tensor->n_dims; i++) { + ctx->infos[idx].n_dims = wsp_ggml_n_dims(tensor); + for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) { ctx->infos[idx].ne[i] = tensor->ne[i]; } @@ -19465,6 +19927,14 @@ int wsp_ggml_cpu_has_avx(void) { #endif } +int wsp_ggml_cpu_has_avx_vnni(void) { +#if defined(__AVXVNNI__) + return 1; +#else + return 0; +#endif +} + int wsp_ggml_cpu_has_avx2(void) { #if defined(__AVX2__) return 1; diff --git a/cpp/ggml.h b/cpp/ggml.h index ca23863..2e2dbfc 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -187,6 +187,16 @@ # define WSP_GGML_API #endif +#ifdef WSP_GGML_MULTIPLATFORM +# if defined(_WIN32) +# define WSP_GGML_CALL +# else +# define WSP_GGML_CALL __attribute__((__ms_abi__)) +# endif +#else +# define WSP_GGML_CALL +#endif + // TODO: support for clang #ifdef __GNUC__ # define WSP_GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) @@ -218,7 +228,9 @@ #define WSP_GGML_MAX_PARAMS 2048 #define WSP_GGML_MAX_CONTEXTS 64 #define WSP_GGML_MAX_SRC 10 +#ifndef WSP_GGML_MAX_NAME #define WSP_GGML_MAX_NAME 64 +#endif #define WSP_GGML_MAX_OP_PARAMS 64 #define WSP_GGML_DEFAULT_N_THREADS 4 #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048 @@ -255,6 +267,8 @@ #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached") #elif defined(__GNUC__) #define WSP_GGML_UNREACHABLE() __builtin_unreachable() +#elif defined(_MSC_VER) +#define WSP_GGML_UNREACHABLE() __assume(0) #else #define WSP_GGML_UNREACHABLE() ((void) 0) #endif @@ -303,7 +317,7 @@ extern "C" { #if defined(__ARM_NEON) && defined(__CUDACC__) typedef half wsp_ggml_fp16_t; -#elif defined(__ARM_NEON) +#elif defined(__ARM_NEON) && !defined(_MSC_VER) typedef __fp16 wsp_ggml_fp16_t; #else typedef uint16_t 
wsp_ggml_fp16_t; @@ -337,12 +351,20 @@ extern "C" { WSP_GGML_TYPE_Q5_K = 13, WSP_GGML_TYPE_Q6_K = 14, WSP_GGML_TYPE_Q8_K = 15, + WSP_GGML_TYPE_IQ2_XXS = 16, + WSP_GGML_TYPE_IQ2_XS = 17, WSP_GGML_TYPE_I8, WSP_GGML_TYPE_I16, WSP_GGML_TYPE_I32, WSP_GGML_TYPE_COUNT, }; + // precision + enum wsp_ggml_prec { + WSP_GGML_PREC_DEFAULT, + WSP_GGML_PREC_F32, + }; + enum wsp_ggml_backend_type { WSP_GGML_BACKEND_CPU = 0, WSP_GGML_BACKEND_GPU = 10, @@ -365,6 +387,8 @@ extern "C" { WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors + WSP_GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors }; // available tensor operations: @@ -478,7 +502,8 @@ extern "C" { enum wsp_ggml_log_level { WSP_GGML_LOG_LEVEL_ERROR = 2, WSP_GGML_LOG_LEVEL_WARN = 3, - WSP_GGML_LOG_LEVEL_INFO = 4 + WSP_GGML_LOG_LEVEL_INFO = 4, + WSP_GGML_LOG_LEVEL_DEBUG = 5 }; // ggml object @@ -502,7 +527,6 @@ extern "C" { struct wsp_ggml_backend_buffer * buffer; - int n_dims; int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes: // nb[0] = wsp_ggml_type_size(type) @@ -534,7 +558,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[12]; + char padding[8]; }; static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor); @@ -635,33 +659,41 @@ extern "C" { WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj); WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx); - WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor); - WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor); - WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor); - WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN - WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split); + WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN - WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type); - WSP_GGML_API size_t wsp_ggml_type_size (enum wsp_ggml_type type); // size in bytes for all elements in a block - WSP_GGML_API float wsp_ggml_type_sizef(enum wsp_ggml_type type); // wsp_ggml_type_size()/wsp_ggml_blck_size() as float + WSP_GGML_API WSP_GGML_CALL int wsp_ggml_blck_size(enum wsp_ggml_type type); + WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block + WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row - WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type); - WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op); - WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op); + WSP_GGML_DEPRECATED( + WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float + "use 
wsp_ggml_row_size() instead"); - WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op); - WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name + WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_type_name(enum wsp_ggml_type type); + WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_name (enum wsp_ggml_op op); + WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op); - WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op); + WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name - WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type); + WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor); + + WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_quantized(enum wsp_ggml_type type); // TODO: temporary until model loading of ggml examples is refactored WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype); - WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor); - WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor); - WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor); + WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1); @@ -722,8 +754,8 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src); // Context tensor enumeration and lookup - WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx); - WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor); @@ -748,7 +780,7 @@ extern "C" { WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor); WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor); - WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor); + WSP_GGML_API WSP_GGML_CALL enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor); WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * 
tensor); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name); @@ -1050,6 +1082,12 @@ extern "C" { struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b); + // change the precision of a matrix multiplication + // set to WSP_GGML_PREC_F32 for higher precision (useful for phi-2) + WSP_GGML_API void wsp_ggml_mul_mat_set_prec( + struct wsp_ggml_tensor * a, + enum wsp_ggml_prec prec); + // indirect matrix multiplication // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b) WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id( @@ -1075,13 +1113,13 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b); + float s); // in-place, returns view(a) WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b); + float s); // b -> view(a,offset,nb1,nb2,3), return modified a WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set( @@ -1137,22 +1175,16 @@ extern "C" { struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b); - // a -> b, in-place, return view(b) - WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy_inplace( + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cast( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b); + enum wsp_ggml_type type); // make contiguous WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a); - // make contiguous, in-place - WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_inplace( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a); - // make contiguous, with new shape WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d( struct wsp_ggml_context * ctx, @@ -1391,7 +1423,7 @@ extern "C" { float beta_slow); // compute correction dims for YaRN RoPE scaling - void wsp_ggml_rope_yarn_corr_dims( + WSP_GGML_CALL void wsp_ggml_rope_yarn_corr_dims( int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // xPos RoPE, in-place, returns view(a) @@ -1825,8 +1857,8 @@ extern "C" { // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data - WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/); - WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan); + WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (const struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/); + WSP_GGML_API int wsp_ggml_graph_compute( struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan); // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data @@ -2033,6 +2065,18 @@ extern "C" { // quantization // + // - wsp_ggml_wsp_quantize_init can be called multiple times with the same type + // it will only initialize the quantization tables for the first call or after wsp_ggml_wsp_quantize_free + // automatically called by wsp_ggml_wsp_quantize_chunk for convenience + // + // - wsp_ggml_wsp_quantize_free will free any memory allocated by wsp_ggml_wsp_quantize_init + // call this at the end of 
the program to avoid memory leaks + // + // note: these are thread-safe + // + WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type); + WSP_GGML_API void wsp_ggml_wsp_quantize_free(void); + // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); @@ -2046,7 +2090,12 @@ extern "C" { WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); - WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + // some quantization type cannot be used without an importance matrix + WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type); + + // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory) + WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, + int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix); // // gguf @@ -2116,10 +2165,11 @@ extern "C" { WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id); WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i); - WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx); - WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name); - WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i); + WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx); + WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name); + WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i); + WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i); + WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type (const struct wsp_gguf_context * ctx, int i); // overrides existing values or adds a new one WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val); @@ -2175,6 +2225,7 @@ extern "C" { // WSP_GGML_API int wsp_ggml_cpu_has_avx (void); + WSP_GGML_API int wsp_ggml_cpu_has_avx_vnni (void); WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void); WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void); WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void); diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp index 9e1befb..2fd513c 100644 --- a/cpp/whisper.cpp +++ b/cpp/whisper.cpp @@ -122,9 +122,18 @@ WHISPER_ATTRIBUTE_FORMAT(2, 3) static void whisper_log_internal (wsp_ggml_log_level level, const char * format, ...); static void whisper_log_callback_default(wsp_ggml_log_level level, const char * text, void * user_data); -#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__) -#define WHISPER_LOG_WARN(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define WHISPER_LOG_ERROR(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define WHISPER_LOG_WARN(...) 
whisper_log_internal(WSP_GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define WHISPER_LOG_INFO(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define WHISPER_DEBUG + +#if defined(WHISPER_DEBUG) +#define WHISPER_LOG_DEBUG(...) whisper_log_internal(WSP_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define WHISPER_LOG_DEBUG(...) +#endif #define WHISPER_ASSERT(x) \ do { \ @@ -134,18 +143,6 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char * } \ } while (0) -// define this to enable verbose trace logging - useful for debugging purposes -//#define WHISPER_DEBUG - -#if defined(WHISPER_DEBUG) -#define WHISPER_PRINT_DEBUG(...) \ - do { \ - fprintf(stderr, __VA_ARGS__); \ - } while (0) -#else -#define WHISPER_PRINT_DEBUG(...) -#endif - //#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 8 @@ -155,7 +152,7 @@ static void whisper_log_callback_default(wsp_ggml_log_level level, const char * // ggml helpers // -static void wsp_ggml_graph_compute_helper( +static bool wsp_ggml_graph_compute_helper( struct wsp_ggml_cgraph * graph, std::vector & buf, int n_threads, @@ -171,10 +168,10 @@ static void wsp_ggml_graph_compute_helper( plan.work_data = buf.data(); } - wsp_ggml_graph_compute(graph, &plan); + return wsp_ggml_graph_compute(graph, &plan); } -static void wsp_ggml_graph_compute_helper( +static bool wsp_ggml_graph_compute_helper( struct wsp_ggml_backend * backend, struct wsp_ggml_cgraph * graph, int n_threads) { @@ -186,7 +183,7 @@ static void wsp_ggml_graph_compute_helper( wsp_ggml_backend_metal_set_n_cb(backend, n_threads); } #endif - wsp_ggml_backend_graph_compute(backend, graph); + return wsp_ggml_backend_graph_compute(backend, graph); } // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" @@ -487,8 +484,8 @@ static size_t whisper_allocr_size(struct whisper_allocr & allocr) { // measure the memory usage of a graph and prepare the allocr's internal data buffer static void whisper_allocr_graph_init(struct whisper_allocr & allocr, wsp_ggml_backend_t backend, std::function && get_graph) { - auto & alloc = allocr.alloc; - auto & meta = allocr.meta; + auto & alloc = allocr.alloc; + auto & meta = allocr.meta; alloc = wsp_ggml_allocr_new_measure_from_backend(backend); @@ -704,7 +701,7 @@ struct whisper_model { struct wsp_ggml_context * ctx; // the model backend data is read-only and can be shared between processors - struct wsp_ggml_backend_buffer * buffer; + std::vector buffers; // tensors int n_loaded; @@ -1073,7 +1070,7 @@ static wsp_ggml_backend_t whisper_backend_init(const whisper_context_params & pa #ifdef WSP_GGML_USE_METAL if (params.use_gpu) { WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); - wsp_ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + wsp_ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); backend_gpu = wsp_ggml_backend_metal_init(); if (!backend_gpu) { WHISPER_LOG_ERROR("%s: wsp_ggml_backend_metal_init() failed\n", __func__); @@ -1517,24 +1514,64 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.backend = whisper_backend_init(wctx.params); + // some devices have a limit on the maximum size of single memory buffer + // for example, iPhones are limited to 1GB per buffer + // to workaround this, we will allocate multiple buffers of smaller size and will split the tensors with the + // 
model weights between them + // + // the map_t2b maps tensor names to buffer indices + // as we iterate over the tensors, we will allocate new buffers when the current one is full + // + // finally, we create a separate allocator for each buffer and use it to allocate the tensors + // we keep the allocators alive until all the tensors are loaded + + WSP_GGML_ASSERT(model.buffers.empty()); + + std::map map_t2b; + { size_t size_main = 0; + size_t size_cur = 0; + + static const size_t GB = 1024ull*1024ull*1024ull; for (const auto & t : model.tensors) { - size_main += wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead(); + const size_t cur = wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead(); + + // adding the tensor to the current buffer will exceed the limit, so we need to allocate a new buffer + if (size_cur + cur > GB) { + WSP_GGML_ASSERT(size_cur > 0 && "A tensor is too large to fit in a single buffer"); + + model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur)); + + size_cur = cur; + } + + map_t2b[t.first] = model.buffers.size(); + + size_cur += cur; + size_main += cur; } - model.buffer = wsp_ggml_backend_alloc_buffer(wctx.backend, size_main); + // allocate the last buffer if needed + if (size_cur > 0) { + model.buffers.emplace_back(wsp_ggml_backend_alloc_buffer(wctx.backend, size_cur)); + } - WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6); + WSP_GGML_ASSERT(model.buffers.size() > 0); + + WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB (%d buffers)\n", __func__, wsp_ggml_backend_name(wctx.backend), size_main / 1e6, (int) model.buffers.size()); } - wsp_ggml_allocr * alloc = wsp_ggml_allocr_new_from_buffer(model.buffer); + std::vector allocs(model.buffers.size()); + for (size_t i = 0; i < allocs.size(); ++i) { + allocs[i] = wsp_ggml_allocr_new_from_buffer(model.buffers[i]); + } // allocate tensors in the backend buffers { for (const auto & t : model.tensors) { - wsp_ggml_allocr_alloc(alloc, t.second); + wsp_ggml_allocr_alloc(allocs[map_t2b[t.first]], t.second); } } @@ -1635,7 +1672,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - wsp_ggml_allocr_free(alloc); + for (auto & alloc : allocs) { + wsp_ggml_allocr_free(alloc); + } wctx.t_load_us = wsp_ggml_time_us() - t_start_us; @@ -1777,7 +1816,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); - wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc; + //wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc; //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_ctx, n_state); //wsp_ggml_allocr_alloc(alloc, cur); @@ -1787,13 +1826,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( //} struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_conv); - struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1); - wsp_ggml_allocr_alloc(alloc, KQscale); - - if (!wsp_ggml_allocr_is_measure(alloc)) { - const float val = 1.0f/sqrtf(float(n_state)/n_head); - wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); - } + const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1843,14 +1876,14 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( Qcur = wsp_ggml_add(ctx0, Qcur, 
layer.attn_q_b); - //Qcur = wsp_ggml_scale(ctx0, Qcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + //Qcur = wsp_ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25)); // note: no bias for Key struct wsp_ggml_tensor * Kcur = wsp_ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = wsp_ggml_scale(ctx0, Kcur, wsp_ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + //Kcur = wsp_ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25)); struct wsp_ggml_tensor * Vcur = wsp_ggml_mul_mat(ctx0, layer.attn_v_w, @@ -2032,7 +2065,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross( wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0); - wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc; + //wsp_ggml_allocr * alloc = wstate.alloc_cross.alloc; //struct wsp_ggml_tensor * cur = wsp_ggml_new_tensor_2d(ctx0, WSP_GGML_TYPE_F32, n_state, n_ctx); //wsp_ggml_allocr_alloc(alloc, cur); @@ -2042,13 +2075,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_cross( //} struct wsp_ggml_tensor * cur = wsp_ggml_view_tensor(ctx0, wstate.embd_enc); - struct wsp_ggml_tensor * Kscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1); - wsp_ggml_allocr_alloc(alloc, Kscale); - - if (!wsp_ggml_allocr_is_measure(alloc)) { - const float val = pow(float(n_state) / n_head, -0.25); - wsp_ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float)); - } + const float Kscale = pow(float(n_state) / n_head, -0.25); for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; @@ -2118,7 +2145,9 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); + if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } } } @@ -2132,7 +2161,9 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); - wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); + if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } } // cross @@ -2145,7 +2176,9 @@ static bool whisper_encode_internal( wsp_ggml_allocr_alloc_graph(alloc, gf); - wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); + if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } } wstate.t_encode_us += wsp_ggml_time_us() - t_start_us; @@ -2178,7 +2211,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( const int32_t n_kv = wsp_ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n; const int32_t kv_head = wsp_ggml_allocr_is_measure(alloc) ? 
n_ctx - n_tokens : kv_self.head; - //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); + //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); struct wsp_ggml_init_params params = { /*.mem_size =*/ wstate.alloc_decode.meta.size(), @@ -2207,13 +2240,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( } } - struct wsp_ggml_tensor * KQscale = wsp_ggml_new_tensor_1d(ctx0, WSP_GGML_TYPE_F32, 1); - wsp_ggml_allocr_alloc(alloc, KQscale); - - if (!wsp_ggml_allocr_is_measure(alloc)) { - const float val = pow(float(n_state)/n_head, -0.25); - wsp_ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); - } + const float KQscale = pow(float(n_state)/n_head, -0.25); struct wsp_ggml_tensor * KQ_mask = wsp_ggml_new_tensor_3d(ctx0, WSP_GGML_TYPE_F32, n_kv, n_tokens, 1); wsp_ggml_allocr_alloc(alloc, KQ_mask); @@ -2573,7 +2600,9 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads); + if (!wsp_ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } } logits_out.resize(n_tokens*n_vocab); @@ -3393,8 +3422,10 @@ void whisper_free(struct whisper_context * ctx) { wsp_ggml_free(ctx->model.ctx); } - if (ctx->model.buffer) { - wsp_ggml_backend_buffer_free(ctx->model.buffer); + for (auto & buffer : ctx->model.buffers) { + if (buffer) { + wsp_ggml_backend_buffer_free(buffer); + } } whisper_free_state(ctx->state); @@ -3838,6 +3869,7 @@ void whisper_reset_timings(struct whisper_context * ctx) { ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; ctx->state->t_decode_us = 0; + ctx->state->t_batchd_us = 0; ctx->state->t_prompt_us = 0; ctx->state->n_sample = 0; ctx->state->n_encode = 0; @@ -4966,7 +4998,7 @@ static void whisper_sequence_score( const auto p = kv.second/(double)cnt; entropy -= p*log(p); - //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + //WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); } sequence.entropy = entropy; @@ -5032,7 +5064,7 @@ int whisper_full_with_state( // basically don't process anything that is less than 1.0s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 if (seek_end < seek_start + (params.speed_up ? 
50 : 100)) { - WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10); + WHISPER_LOG_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10); return 0; } @@ -5221,7 +5253,7 @@ int whisper_full_with_state( n_decoders_cur = std::max(1, n_decoders_cur); - WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); + WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { @@ -5265,11 +5297,11 @@ int whisper_full_with_state( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); // print the prompt - WHISPER_PRINT_DEBUG("\n\n"); + WHISPER_LOG_DEBUG("\n\n"); for (int i = 0; i < (int) prompt.size(); i++) { - WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); } - WHISPER_PRINT_DEBUG("\n\n"); + WHISPER_LOG_DEBUG("\n\n"); whisper_kv_cache_clear(state->kv_self); @@ -5417,7 +5449,7 @@ int whisper_full_with_state( whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); - WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); } @@ -5460,7 +5492,7 @@ int whisper_full_with_state( // do not allow to go back in time if (has_ts && seek_delta > seek_delta_new && result_len < i) { - WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new); + WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new); failed = true; // TODO: maybe this is not a failure ? continue; } @@ -5475,7 +5507,7 @@ int whisper_full_with_state( #ifdef WHISPER_DEBUG { const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token.at(token.tid) : "[?]"; - WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); } #endif @@ -5485,22 +5517,22 @@ int whisper_full_with_state( (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached ) { - if (result_len == 0) { + if (result_len == 0 && !params.no_timestamps) { if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; } else { - WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j); + WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j); failed = true; continue; } } - if (params.single_segment) { + if (params.single_segment || params.no_timestamps) { result_len = i + 1; seek_delta = 100*WHISPER_CHUNK_SIZE; } - WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j); + WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j); completed = true; continue; } @@ -5516,7 +5548,7 @@ int whisper_full_with_state( // sometimes, the decoding can get stuck in a repetition loop // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j); + WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j); failed = true; continue; } @@ -5558,7 +5590,7 @@ int whisper_full_with_state( continue; } - //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); + //WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); decoder.i_batch = batch.n_tokens; @@ -5638,11 +5670,11 @@ int whisper_full_with_state( decoder.sequence.tokens.resize(decoder.sequence.result_len); whisper_sequence_score(params, decoder.sequence); - WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { - WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; @@ -5657,7 +5689,7 @@ int whisper_full_with_state( } } - WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); } bool success = true; @@ -5669,7 +5701,7 @@ int whisper_full_with_state( const auto & decoder = state->decoders[best_decoder_id]; if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { - WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold); + 
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold); success = false; state->n_fail_p++; } @@ -5677,13 +5709,13 @@ int whisper_full_with_state( if (success) { //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + // WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); //} break; } - WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } // output results through a user-provided callback @@ -5695,7 +5727,7 @@ int whisper_full_with_state( const auto & tokens_cur = best_decoder.sequence.tokens; - //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); // update prompt_past prompt_past.clear(); @@ -5815,7 +5847,7 @@ int whisper_full_with_state( // update audio window seek += seek_delta; - WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); } } @@ -6132,7 +6164,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { // multi-thread - for (uint32_t k = 1; k <= n_threads; k++) { + for (int32_t k = 1; k <= n_threads; k++) { char * src = (char *) malloc(size); char * dst = (char *) malloc(size); @@ -6156,13 +6188,13 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { const int64_t t0 = wsp_ggml_time_us(); std::vector threads(k - 1); - for (uint32_t th = 0; th < k - 1; ++th) { + for (int32_t th = 0; th < k - 1; ++th) { threads[th] = std::thread(helper, th); } helper(k - 1); - for (uint32_t th = 0; th < k - 1; ++th) { + for (int32_t th = 0; th < k - 1; ++th) { threads[th].join(); } diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index 8f25231..3701bda 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -1110,7 +1110,7 @@ PODS: - SSZipArchive (~> 2.2) - SocketRocket (0.6.1) - SSZipArchive (2.4.3) - - whisper-rn (0.4.0-rc.6): + - whisper-rn (0.4.0-rc.7): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1357,7 +1357,7 @@ SPEC CHECKSUMS: RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801 SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17 SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef - whisper-rn: ad8b7033093d163d7516ac75ee9aaf1a4744d7b1 + whisper-rn: fee681f3a7f168c6f615879dd8bf4bc002c1e335 Yoga: e64aa65de36c0832d04e8c7bd614396c77a80047 PODFILE CHECKSUM: 80498ebe77b4fb086f5bdde47d6aa337318b67f0 diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 3d5c6fb..21ae6d6 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -76,7 +76,6 @@ yarn example patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch patch -p0 -d ./cpp < ./scripts/whisper.h.patch patch -p0 -d ./cpp < ./scripts/whisper.cpp.patch -patch -p0 -d ./cpp/coreml < ./scripts/whisper-encoder.mm.patch # 
Download model for example cd whisper.cpp/models diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index 60aa2b1..ee19df8 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,6 +1,6 @@ ---- ggml-metal.m.orig 2023-12-15 06:38:34 -+++ ggml-metal.m 2023-12-15 06:38:35 -@@ -231,7 +231,7 @@ +--- ggml-metal.m.orig 2024-01-19 10:55:39 ++++ ggml-metal.m 2024-01-19 10:58:09 +@@ -254,7 +254,7 @@ WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); // Configure context @@ -9,7 +9,7 @@ ctx->device = device; ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS); ctx->queue = [ctx->device newCommandQueue]; -@@ -265,7 +265,7 @@ +@@ -288,7 +288,7 @@ if (ggmlMetalPathResources) { sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; } else { @@ -18,33 +18,35 @@ } if (sourcePath == nil) { WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); -@@ -451,8 +451,6 @@ - void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { - WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); - #define WSP_GGML_METAL_DEL_KERNEL(name) \ -- [ctx->function_##name release]; \ -- [ctx->pipeline_##name release]; - - WSP_GGML_METAL_DEL_KERNEL(add); - WSP_GGML_METAL_DEL_KERNEL(add_row); -@@ -565,16 +563,6 @@ - WSP_GGML_METAL_DEL_KERNEL(sum_rows); +@@ -530,27 +530,7 @@ - #undef WSP_GGML_METAL_DEL_KERNEL + static void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { + WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); - - for (int i = 0; i < ctx->n_buffers; ++i) { - [ctx->buffers[i].metal release]; - } - +- for (int i = 0; i < WSP_GGML_METAL_MAX_KERNELS; ++i) { +- if (ctx->kernels[i].pipeline) { +- [ctx->kernels[i].pipeline release]; +- } +- +- if (ctx->kernels[i].function) { +- [ctx->kernels[i].function release]; +- } +- } + - [ctx->library release]; - [ctx->queue release]; - [ctx->device release]; - - dispatch_release(ctx->d_queue); - +- free(ctx); } -@@ -2380,7 +2368,6 @@ + +@@ -2278,7 +2258,6 @@ g_backend_device_ref_count--; if (g_backend_device_ref_count == 0) { @@ -52,11 +54,13 @@ g_backend_device = nil; } } -@@ -2394,7 +2381,6 @@ - static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { +@@ -2292,9 +2271,6 @@ + WSP_GGML_CALL static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context; -- [ctx->metal release]; +- for (int i = 0; i < ctx->n_buffers; i++) { +- [ctx->buffers[i].metal release]; +- } wsp_ggml_backend_metal_free_device(); - free(ctx->data); + if (ctx->owned) { diff --git a/scripts/whisper-encoder.mm.patch b/scripts/whisper-encoder.mm.patch deleted file mode 100644 index 5beb10b..0000000 --- a/scripts/whisper-encoder.mm.patch +++ /dev/null @@ -1,14 +0,0 @@ ---- whisper-encoder.mm.orig 2023-10-05 09:49:44 -+++ whisper-encoder.mm 2023-10-05 09:49:59 -@@ -24,9 +24,9 @@ - - // select which device to run the Core ML model on - MLModelConfiguration *config = [[MLModelConfiguration alloc] init]; -- config.computeUnits = MLComputeUnitsCPUAndGPU; -+ //config.computeUnits = MLComputeUnitsCPUAndGPU; - //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; -- //config.computeUnits = MLComputeUnitsAll; -+ config.computeUnits = MLComputeUnitsAll; - - const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] 
initWithContentsOfURL:url_model configuration:config error:nil]); - diff --git a/scripts/whisper.cpp.patch b/scripts/whisper.cpp.patch index 09697f8..ef33614 100644 --- a/scripts/whisper.cpp.patch +++ b/scripts/whisper.cpp.patch @@ -1,25 +1,25 @@ ---- whisper.cpp.orig 2023-12-09 10:50:29 -+++ whisper.cpp 2023-12-09 10:50:30 -@@ -3043,8 +3043,10 @@ +--- whisper.cpp.orig 2024-01-19 11:01:00 ++++ whisper.cpp 2024-01-19 11:01:01 +@@ -3072,8 +3072,10 @@ const size_t memory_size = wsp_ggml_nbytes(state->kv_cross.k) + wsp_ggml_nbytes(state->kv_cross.v); WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } + - + #ifdef WHISPER_USE_COREML + if (ctx->params.use_coreml) { const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); -@@ -3059,6 +3061,7 @@ +@@ -3088,6 +3090,7 @@ #endif } else { WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); + } } #endif - -@@ -3184,6 +3187,7 @@ + +@@ -3213,6 +3216,7 @@ struct whisper_context_params whisper_context_default_params() { struct whisper_context_params result = { /*.use_gpu =*/ true, diff --git a/src/version.json b/src/version.json index 250af09..477da68 100644 --- a/src/version.json +++ b/src/version.json @@ -1 +1 @@ -{"version":"1.5.2"} \ No newline at end of file +{"version":"1.5.4"} \ No newline at end of file diff --git a/whisper.cpp b/whisper.cpp index 940de9d..c0329ac 160000 --- a/whisper.cpp +++ b/whisper.cpp @@ -1 +1 @@ -Subproject commit 940de9dbe9c90624dc99521cb34c8a97b86d543c +Subproject commit c0329acde8a7d2b03add7e7c8f5e5341b48746ff
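
The notes below are illustrative sketches added for readers of this sync; they are not part of the patched files. First, two call-site adjustments in the public ggml API shown in the hunks above: wsp_ggml_scale() now takes a plain float instead of a 1-element tensor, and the new wsp_ggml_mul_mat_set_prec() can request F32 accumulation for a matrix multiplication (the header comment suggests this for phi-2). The fragment assumes the surrounding names (ctx0, Q, K, n_state, n_head) from whisper_build_graph_encoder() and is not compilable on its own:

    // old: the scale factor had to be created as a 1-element tensor and filled
    //      through the allocator, e.g. wsp_ggml_scale(ctx0, KQ, KQscale_tensor)
    // new: pass the factor directly as a float
    struct wsp_ggml_tensor * KQ = wsp_ggml_mul_mat(ctx0, K, Q);
    wsp_ggml_mul_mat_set_prec(KQ, WSP_GGML_PREC_F32);   // optional: accumulate this mul_mat in F32
    struct wsp_ggml_tensor * KQ_scaled = wsp_ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_state)/n_head));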
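
The quantization entry points gain an init/free pair and a reworked wsp_ggml_wsp_quantize_chunk() that takes nrows and n_per_row plus an optional importance matrix. Below is a minimal, self-contained sketch of one way a caller might use it; the WSP_GGML_TYPE_Q4_0 enum name, the 16-bin histogram size, and the "ggml.h" include path are assumptions based on the library's naming conventions, everything else follows the declarations in the hunk above:

    #include <cstdint>
    #include <cstddef>
    #include <vector>
    #include "ggml.h"

    // Quantize an nrows x n_per_row FP32 matrix with the chunk API from this sync.
    // The output buffer is sized like the FP32 input, which is always large enough
    // for the quantized data.
    static std::vector<uint8_t> quantize_weights(const float * src, int nrows, int n_per_row) {
        const enum wsp_ggml_type type = WSP_GGML_TYPE_Q4_0; // assumed enum spelling (wsp_ prefix convention)

        std::vector<uint8_t> dst(sizeof(float) * (size_t) nrows * n_per_row);
        int64_t hist[16] = {0};                             // histogram filled by the quantizer (assumed 16 bins)

        // wsp_ggml_wsp_quantize_chunk() calls wsp_ggml_wsp_quantize_init() internally,
        // so no explicit init is needed; imatrix may stay null for types that do not
        // require an importance matrix
        size_t written = 0;
        if (!wsp_ggml_wsp_quantize_requires_imatrix(type)) {
            written = wsp_ggml_wsp_quantize_chunk(type, src, dst.data(),
                                                  /*start =*/ 0, nrows, n_per_row, hist, /*imatrix =*/ nullptr);
        }
        dst.resize(written);

        wsp_ggml_wsp_quantize_free(); // release the tables allocated by the implicit init
        return dst;
    }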
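
wsp_gguf_get_tensor_type() is new in this sync; together with the existing accessors it makes it possible to enumerate a GGUF file's tensors along with their storage types. A small sketch, again assuming only the "ggml.h" include path:

    #include <cstdio>
    #include "ggml.h" // the gguf API is declared in this header in the synced version

    static void dump_tensor_types(const struct wsp_gguf_context * ctx) {
        const int n = wsp_gguf_get_n_tensors(ctx);
        for (int i = 0; i < n; ++i) {
            // wsp_gguf_get_tensor_type() is the accessor added by this sync
            printf("%-40s %s\n",
                   wsp_gguf_get_tensor_name(ctx, i),
                   wsp_ggml_type_name(wsp_gguf_get_tensor_type(ctx, i)));
        }
    }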
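
Finally, whisper_model_load() now spreads the model weights over several backend buffers because some devices (iPhones, per the comment in the hunk) cap a single buffer at roughly 1 GB. The sketch below restates that strategy in a compact, simplified form rather than line for line; the wsp_ggml_backend_buffer_t typedef and the "ggml-alloc.h"/"ggml-backend.h" include paths are assumptions, while all function calls appear in the hunks above:

    #include <map>
    #include <string>
    #include <vector>
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // Split weights across backend buffers of at most ~1 GB each, then allocate
    // every tensor inside the buffer chosen for it. The returned buffers must be
    // kept alive for the lifetime of the model.
    static std::vector<wsp_ggml_backend_buffer_t> alloc_weights_split(
            wsp_ggml_backend_t backend,
            const std::map<std::string, struct wsp_ggml_tensor *> & tensors) {
        static const size_t GB = 1024ull*1024ull*1024ull;

        std::vector<wsp_ggml_backend_buffer_t> buffers;
        std::map<std::string, size_t> tensor_to_buffer;  // tensor name -> index into buffers

        size_t size_cur = 0;
        for (const auto & t : tensors) {
            const size_t cur = wsp_ggml_nbytes(t.second) + wsp_ggml_tensor_overhead();
            if (size_cur > 0 && size_cur + cur > GB) {   // would overflow the cap: seal the current buffer
                buffers.emplace_back(wsp_ggml_backend_alloc_buffer(backend, size_cur));
                size_cur = 0;
            }
            tensor_to_buffer[t.first] = buffers.size();  // index of the buffer still being filled
            size_cur += cur;
        }
        if (size_cur > 0) {
            buffers.emplace_back(wsp_ggml_backend_alloc_buffer(backend, size_cur));
        }

        // one allocator per buffer, used only while placing the tensors
        std::vector<wsp_ggml_allocr *> allocs(buffers.size());
        for (size_t i = 0; i < allocs.size(); ++i) {
            allocs[i] = wsp_ggml_allocr_new_from_buffer(buffers[i]);
        }
        for (const auto & t : tensors) {
            wsp_ggml_allocr_alloc(allocs[tensor_to_buffer[t.first]], t.second);
        }
        for (auto & alloc : allocs) {
            wsp_ggml_allocr_free(alloc);                 // the buffers keep the tensor data
        }
        return buffers;
    }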