update HIP_UMA ggml-org#7399
Add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled.
- gets roughly 2x on prompt eval and 1.5x on token gen with ROCm 6.0 on a Ryzen 7940HX iGPU (780M/gfx1103)
Djip007 committed May 20, 2024
1 parent d359f30 commit b2c0f7f
Showing 2 changed files with 21 additions and 5 deletions.
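
Before the diff, a minimal standalone sketch (not part of the commit) of the allocation pattern the patch wraps into ggml_cuda_device_malloc: allocate with hipMallocManaged, then mark the range coarse-grained with hipMemAdvise(hipMemAdviseSetCoarseGrain). The helper name uma_alloc, the 1 MiB size, and device 0 are illustrative assumptions; it presumes a ROCm 6.x toolchain and a UMA-capable iGPU.

// Hedged sketch, not part of the commit: the hipMallocManaged + coarse-grain
// advice pattern shown standalone. Build with something like: hipcc uma_alloc.cpp
#include <hip/hip_runtime.h>
#include <cstdio>

// Illustrative helper (hypothetical name); mirrors what the patch does under GGML_HIP_UMA.
static hipError_t uma_alloc(void ** ptr, size_t size, int device) {
    hipError_t res = hipMallocManaged(ptr, size);   // shared host/device allocation
    if (res == hipSuccess) {
        // Advise the runtime to treat the range as coarse-grained memory,
        // reducing fine-grained coherence overhead on integrated GPUs.
        res = hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
    }
    return res;
}

int main() {
    void * buf = nullptr;
    hipError_t err = uma_alloc(&buf, 1 << 20, /*device=*/0); // 1 MiB on device 0
    if (err != hipSuccess) {
        fprintf(stderr, "allocation failed: %s\n", hipGetErrorString(err));
        return 1;
    }
    hipFree(buf);
    return 0;
}

The diff below adds the same hipMemAdvise call after hipMallocManaged inside ggml_cuda_device_malloc, which is presumably where the reported prompt-eval speedup comes from.
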
24 changes: 21 additions & 3 deletions ggml-cuda.cu
@@ -119,6 +119,24 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+// ggml_cuda_host_malloc
+static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+#if defined(GGML_USE_HIPBLAS)
+#if defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return hipMalloc(ptr, size);
+#endif
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +289,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         size_t look_ahead_size = (size_t) (1.05 * size);
         look_ahead_size = 256 * ((look_ahead_size + 255)/256);
         ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
         *actual_size = look_ahead_size;
         pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +555,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
@@ -798,7 +816,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
2 changes: 0 additions & 2 deletions ggml-cuda/common.cuh
@@ -80,10 +80,8 @@
 #define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
 #ifdef GGML_HIP_UMA
-#define cudaMalloc hipMallocManaged
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
 #else
-#define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #endif
 #define cudaMemcpy hipMemcpy
