Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24

#define GGML_MROPE_SECTIONS 4

#define GGML_UNUSED(x) (void)(x)

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
Expand Down Expand Up @@ -1456,7 +1458,7 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int sections[GGML_MROPE_SECTIONS],
int mode,
int n_ctx_orig,
float freq_base,
Expand All @@ -1482,6 +1484,22 @@ extern "C" {
float beta_fast,
float beta_slow);

// In-place variant of ggml_rope_multi: the same arguments are forwarded to the shared
// rope implementation with its inplace flag set, so the result is created as a view
// of `a` instead of a duplicate of it.
//
// b        - I32 vector of position ids; when GGML_ROPE_TYPE_MROPE is set in mode the
//            implementation asserts 4 ids per token (b->ne[0] == a->ne[2] * 4),
//            otherwise b->ne[0] == a->ne[2]
// c        - optional F32 tensor (may be NULL); when given, c->ne[0] >= n_dims / 2 is asserted
// sections - GGML_MROPE_SECTIONS ints copied into the op params when GGML_ROPE_TYPE_MROPE
//            is set in mode (zeros are stored otherwise); it is read via memcpy in that
//            case, so it must point to valid storage
GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[GGML_MROPE_SECTIONS],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);

GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
84 changes: 44 additions & 40 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3541,6 +3541,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[GGML_MROPE_SECTIONS],
int mode,
int n_ctx_orig,
float freq_base,
Expand All @@ -3554,15 +3555,19 @@ static struct ggml_tensor * ggml_rope_impl(

GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);

bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
if (mrope_used) {
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
} else {
GGML_ASSERT(a->ne[2] == b->ne[0]);
}

if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}

int sections[4] = {0, 0, 0, 0};

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
Expand All @@ -3572,7 +3577,11 @@ static struct ggml_tensor * ggml_rope_impl(
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &sections, sizeof(int)*4);
if (mrope_used) {
memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
} else {
memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
}
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_ROPE;
Expand All @@ -3590,7 +3599,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode) {
return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
);
}

Expand All @@ -3600,7 +3609,7 @@ struct ggml_tensor * ggml_rope_multi(
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int sections[GGML_MROPE_SECTIONS],
int mode,
int n_ctx_orig,
float freq_base,
Expand All @@ -3609,36 +3618,31 @@ struct ggml_tensor * ggml_rope_multi(
float attn_factor,
float beta_fast,
float beta_slow) {
// Multimodal Rotary Position Embedding
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");

GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token

if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}

struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(&params[11], sections, sizeof(int)*4);
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_ROPE;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return ggml_rope_impl(
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false
);
}

return result;
// Multi-section rope, in-place flavour: identical argument forwarding to
// ggml_rope_multi, except that the shared implementation is asked to operate
// in place (trailing flag = true).
struct ggml_tensor * ggml_rope_multi_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * c,
        int                   n_dims,
        int                   sections[GGML_MROPE_SECTIONS],
        int                   mode,
        int                   n_ctx_orig,
        float                 freq_base,
        float                 freq_scale,
        float                 ext_factor,
        float                 attn_factor,
        float                 beta_fast,
        float                 beta_slow) {
    struct ggml_tensor * result = ggml_rope_impl(
        ctx, a, b, c,
        n_dims, sections, mode, n_ctx_orig,
        freq_base, freq_scale, ext_factor, attn_factor,
        beta_fast, beta_slow,
        /*inplace*/ true);

    return result;
}

struct ggml_tensor * ggml_rope_inplace(
Expand All @@ -3648,7 +3652,7 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode) {
return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
);
}

Expand All @@ -3667,7 +3671,7 @@ struct ggml_tensor * ggml_rope_ext(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false
);
}
Expand All @@ -3687,7 +3691,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, true
);
}
Expand All @@ -3706,7 +3710,7 @@ struct ggml_tensor * ggml_rope_custom(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false
);
}
Expand All @@ -3725,7 +3729,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, true
);
}
Expand Down
Loading