Skip to content

Commit 6dffbaa

Browse files
ggml: backward pass for split swiglu
1 parent 343b6e9 commit 6dffbaa

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

ggml/src/ggml.c

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6009,12 +6009,28 @@ static void ggml_compute_backward(
60096009
}
60106010
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
60116011
} break;
6012+
case GGML_OP_GLU: {
6013+
switch (ggml_get_glu_op(tensor)) {
6014+
case GGML_GLU_OP_SWIGLU: {
6015+
if (src0_needs_grads) {
6016+
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6017+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6018+
}
6019+
if (src1_needs_grads) {
6020+
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6021+
}
6022+
} break;
6023+
default: {
6024+
GGML_LOG_ERROR("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6025+
} break;
6026+
}
6027+
} break;
60126028
case GGML_OP_NONE: {
60136029
// noop
60146030
} break;
60156031
case GGML_OP_COUNT:
60166032
default: {
6017-
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
6033+
GGML_LOG_ERROR("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
60186034
GGML_ABORT("fatal error");
60196035
} //break;
60206036
}

tests/test-backend-ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
11751175
if (v & 1) {
11761176
auto ne = ne_a; ne[0] *= 3;
11771177
a = ggml_new_tensor(ctx, type, 4, ne.data());
1178+
ggml_set_param(a);
11781179
ggml_set_name(a, "a");
11791180

11801181
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
11811182
ggml_set_name(a, "view_of_a");
11821183

11831184
b = ggml_new_tensor(ctx, type, 4, ne.data());
1185+
ggml_set_param(b);
11841186
ggml_set_name(b, "b");
11851187

11861188
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
11871189
ggml_set_name(a, "view_of_b");
11881190
} else {
11891191
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1192+
ggml_set_param(a);
11901193
ggml_set_name(a, "a");
11911194

11921195
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
1196+
ggml_set_param(b);
11931197
ggml_set_name(b, "b");
11941198
}
11951199

0 commit comments

Comments
 (0)