use cpu rope operator to avoid bugs, revert me after fixed
hipudding committed Jul 11, 2024
1 parent 0267af3 commit 759a92f
Showing 3 changed files with 58 additions and 10 deletions.
15 changes: 15 additions & 0 deletions ggml/include/ggml.h
@@ -2425,6 +2425,21 @@ extern "C" {

GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

struct ggml_compute_params {
// ith = thread index, nth = number of threads
int ith, nth;

// work buffer for all threads
size_t wsize;
void * wdata;

struct ggml_compute_state_shared * shared;
};

void ggml_compute_forward_rope(
const struct ggml_compute_params * params,
struct ggml_tensor * dst);

#ifdef __cplusplus
}
#endif
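
For context, here is a minimal sketch (not part of this commit) of how a caller outside ggml.c can now drive the exported CPU kernel single-threaded. The helper name rope_on_cpu is hypothetical, and it assumes dst is a GGML_OP_ROPE tensor whose src[0] (input) and src[1] (positions) already point at host memory:

    #include <stdlib.h>
    #include "ggml.h"

    // Hypothetical helper: run the exported CPU RoPE kernel on a single thread.
    static void rope_on_cpu(struct ggml_tensor * dst) {
        struct ggml_compute_params params = {0};
        params.ith   = 0;    // this thread's index ...
        params.nth   = 1;    // ... out of exactly one thread
        params.wsize = 0;    // assumption: RoPE does not use the shared work
        params.wdata = NULL; // buffer, so none is provided here
        ggml_compute_forward_rope(&params, dst);
    }

With nth == 1 the kernel processes every row itself, so no thread pool or shared state is required.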
42 changes: 42 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -2247,10 +2247,52 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
}

void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO
// RoPE on the NPU has bugs that have not been found yet.
// Use the CPU implementation instead to allow a quick try of the NPU backend.
aclrtSynchronizeStream(ctx.stream());
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors

size_t src0_len = ggml_nbytes(src0);
size_t src1_len = ggml_nbytes(src1);
size_t dst_len = ggml_nbytes(dst);
void* src0_host = malloc(src0_len);
void* src1_host = malloc(src1_len);
void* dst_host = malloc(dst_len);

void* src0_dev_ptr = src0->data;
void* src1_dev_ptr = src1->data;
void* dst_dev_ptr = dst->data;

aclrtMemcpy(src0_host, src0_len, src0_dev_ptr, src0_len, ACL_MEMCPY_DEVICE_TO_HOST);
aclrtMemcpy(src1_host, src1_len, src1_dev_ptr, src1_len, ACL_MEMCPY_DEVICE_TO_HOST);

src0->data = src0_host;
src1->data = src1_host;
dst->data = dst_host;

ggml_compute_params param1;
param1.ith = 0;
param1.nth = 1;
param1.wsize = 102400;
param1.wdata = malloc(102400);

ggml_compute_forward_rope(&param1, dst);

aclrtMemcpy(dst_dev_ptr, dst_len, dst_host, dst_len, ACL_MEMCPY_HOST_TO_DEVICE);

src0->data = src0_dev_ptr;
src1->data = src1_dev_ptr;
dst->data = dst_dev_ptr;

free(src0_host);
free(src1_host);
free(dst_host);
free(param1.wdata);

return;

// TODO: with freq_factors
GGML_ASSERT(src2 == NULL);

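The pattern above (synchronize the stream, copy the operands device-to-host, repoint each tensor's data field at the host buffers, run the CPU kernel, copy the result back, and restore the device pointers) is a generic escape hatch for any misbehaving device operator. Below is a hedged sketch of the same round trip with error handling added; the wrapper name run_on_cpu is an assumption, not part of the commit, and it checks each ACL call against ACL_SUCCESS:

    #include <cstdlib>
    #include <acl/acl.h>
    #include "ggml.h"

    typedef void (*cpu_op_t)(const struct ggml_compute_params *, struct ggml_tensor *);

    // Hypothetical wrapper: copy a tensor's operands to the host, run an
    // exported CPU kernel on them, then copy the result back to the device.
    static bool run_on_cpu(aclrtStream stream, ggml_tensor * t, cpu_op_t cpu_op) {
        // finish any pending device work that produces the inputs
        if (aclrtSynchronizeStream(stream) != ACL_SUCCESS) return false;

        ggml_tensor * src0 = t->src[0];
        ggml_tensor * src1 = t->src[1];
        size_t n0 = ggml_nbytes(src0), n1 = ggml_nbytes(src1), nd = ggml_nbytes(t);

        void * h0 = malloc(n0), * h1 = malloc(n1), * hd = malloc(nd);
        bool ok = h0 && h1 && hd;
        ok = ok && aclrtMemcpy(h0, n0, src0->data, n0, ACL_MEMCPY_DEVICE_TO_HOST) == ACL_SUCCESS;
        ok = ok && aclrtMemcpy(h1, n1, src1->data, n1, ACL_MEMCPY_DEVICE_TO_HOST) == ACL_SUCCESS;

        if (ok) {
            // temporarily point the tensors at host memory for the CPU kernel
            void * d0 = src0->data, * d1 = src1->data, * dd = t->data;
            src0->data = h0; src1->data = h1; t->data = hd;

            struct ggml_compute_params params = {0};
            params.nth = 1;  // single-threaded
            cpu_op(&params, t);

            // push the result back, then restore the device pointers
            ok = aclrtMemcpy(dd, nd, hd, nd, ACL_MEMCPY_HOST_TO_DEVICE) == ACL_SUCCESS;
            src0->data = d0; src1->data = d1; t->data = dd;
        }

        free(h0); free(h1); free(hd);
        return ok;
    }

Called as run_on_cpu(ctx.stream(), dst, ggml_compute_forward_rope), this reproduces the commit's behaviour while surfacing allocation and copy failures.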
11 changes: 1 addition & 10 deletions ggml/src/ggml.c
@@ -1754,16 +1754,7 @@ struct ggml_compute_state {
struct ggml_compute_state_shared * shared;
};

struct ggml_compute_params {
// ith = thread index, nth = number of threads
int ith, nth;

// work buffer for all threads
size_t wsize;
void * wdata;

struct ggml_compute_state_shared * shared;
};

//
// fundamental operations
@@ -14055,7 +14046,7 @@ static void ggml_compute_forward_rope_f16(
}
}

static void ggml_compute_forward_rope(
void ggml_compute_forward_rope(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

