Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add ggml_fft and ggml_ifft operator #1105

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,

GGML_OP_FFT, // Fast Fourier Transform
GGML_OP_IFFT, // Inverse Fast Fourier Transform

GGML_OP_COUNT,
};

Expand Down Expand Up @@ -1775,6 +1778,15 @@ extern "C" {
struct ggml_tensor * a,
int k);

// Fast Fourier Transform operations
GGML_API struct ggml_tensor * ggml_fft(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_ifft(
struct ggml_context * ctx,
struct ggml_tensor * a);

#define GGML_KQ_MASK_PAD 64

// q: [n_embd, n_batch, n_head, 1]
Expand Down
133 changes: 133 additions & 0 deletions src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <errno.h>
#include <time.h>
#include <math.h>
#include <complex.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
Expand Down Expand Up @@ -12716,6 +12717,123 @@ static void ggml_compute_forward_opt_step_adamw(
}
}
}

// FFT

static bool is_power_of_2(int64_t n) {
return n > 0 && (n & (n - 1)) == 0;
}

// Recursive implementation of FFT
static void fft_recursive(float complex *x, int64_t n, bool inverse) {
if (n <= 1) return;

// Split into even and odd
int64_t half = n / 2;
float complex *even = (float complex *)malloc(half * sizeof(float complex));
float complex *odd = (float complex *)malloc(half * sizeof(float complex));

for (int64_t i = 0; i < half; i++) {
even[i] = x[2*i];
odd[i] = x[(2*i+1)];
}

// Recursive FFT on even and odd parts
fft_recursive(even, half, inverse);
fft_recursive(odd, half, inverse);

// Combine results
float angle_factor = inverse ? 2.0 * M_PI / n : -2.0 * M_PI / n;
for (int64_t k = 0; k < half; k++) {
float complex t = cexpf(angle_factor * k * I) * odd[k];
x[k] = even[k] + t;
x[k+half] = even[k] - t;
}

free(even);
free(odd);
}

static void ggml_compute_forward_fft(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src = dst->src[0];

GGML_ASSERT(src->type == GGML_TYPE_F32); // Only support f32 for now
GGML_ASSERT(ggml_is_vector(src)); // Only support 1D FFT for now
GGML_ASSERT(is_power_of_2(src->ne[0])); // Length must be power of 2

if (params->ith != 0) {
return;
}

// Allocate temporary complex array
int64_t n = src->ne[0];
float complex *x = (float complex *)malloc(n * sizeof(float complex));

// Copy input data to complex array
float *src_data = (float *)src->data;
for (int64_t i = 0; i < n; i++) {
x[i] = src_data[i] + 0.0f * I;
}

// Perform FFT
fft_recursive(x, n, false);

// Copy result back
float *dst_data = (float *)dst->data;
for (int64_t i = 0; i < n; i++) {
// Store real and imaginary parts in interleaved format
dst_data[2*i] = crealf(x[i]);
dst_data[2*i+1] = cimagf(x[i]);
}

free(x);
}

static void ggml_compute_forward_ifft(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src = dst->src[0];

GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_vector(src));
GGML_ASSERT(is_power_of_2(src->ne[0]/2));

if (params->ith != 0) {
return;
}

// Allocate temporary complex array
int64_t n = src->ne[0];
float complex *x = (float complex *)malloc(n * sizeof(float complex));

// Copy input data to complex array
float *src_data = (float *)src->data;
for (int64_t i = 0; i < n; i++) {
x[i] = src_data[2*i] + src_data[2*i+1]*I;
}

// Perform recursive IFFT
fft_recursive(x, n, true);

// Scale the result by 1/N
float scale = 1.0f / n;
for (int64_t i = 0; i < n; i++) {
x[i] *= scale;
}

// Copy result back (real part only for IFFT)
float *dst_data = (float *)dst->data;
for (int64_t i = 0; i < n; i++) {
dst_data[i] = crealf(x[i]);
}

free(x);
}

/////////////////////////////////

static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
Expand Down Expand Up @@ -13081,6 +13199,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_opt_step_adamw(params, tensor);
}
break;
case GGML_OP_FFT:
{
ggml_compute_forward_fft(params, tensor);
}
break;
case GGML_OP_IFFT:
{
ggml_compute_forward_ifft(params, tensor);
}
break;
case GGML_OP_NONE:
{
// nop
Expand Down Expand Up @@ -13367,6 +13495,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
GGML_ABORT("fatal error");
}
case GGML_OP_FFT:
case GGML_OP_IFFT:
{
n_tasks = 1;
} break;
default:
{
fprintf(stderr, "%s: op not implemented: ", __func__);
Expand Down
45 changes: 43 additions & 2 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -990,9 +990,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",

"FFT",
"IFFT",
};

static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -1087,9 +1090,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",

"fft(x)",
"ifft(x)",
};

static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -5130,6 +5136,41 @@ struct ggml_tensor * ggml_opt_step_adamw(
return result;
}

// ggml_fft

static struct ggml_tensor * ggml_fft_impl(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_vector(a));

struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, a->ne[0]);

result->op = GGML_OP_FFT;
result->src[0] = a;

return result;
}

struct ggml_tensor * ggml_fft(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_fft_impl(ctx, a);
}

// ggml_ifft

struct ggml_tensor * ggml_ifft(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_vector(a));
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, a->ne[0]/2);

result->op = GGML_OP_IFFT;
result->src[0] = a;

return result;
}

////////////////////////////////////////////////////////////////////////////////

struct ggml_hash_set ggml_hash_set_new(size_t size) {
Expand Down
15 changes: 15 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,21 @@ endif()
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")

#
# test-fft

set(TEST_TARGET test-fft)
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
if (MSVC)
target_link_options(${TEST_TARGET} PRIVATE "/STACK: 8388608") # 8MB
endif()
if (MATH_LIBRARY)
target_link_libraries(${TEST_TARGET} PRIVATE ${MATH_LIBRARY})
endif()
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")

#
# test-pad-reflect-1d

Expand Down
97 changes: 97 additions & 0 deletions tests/test-fft.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <complex.h>

#include "ggml.h"
#include "ggml-cpu.h"

#define N_SAMPLES 16 // Must be power of 2

// Helper function to generate a simple test signal
void generate_test_signal(float * signal, int n) {
// Generate a simple sinusoidal signal
for (int i = 0; i < n; i++) {
signal[i] = sinf(2.0f * M_PI * i / n) + 0.5f * sinf(4.0f * M_PI * i / n);
}
}

// Helper function to compare arrays with tolerance
bool compare_arrays(float * a, float * b, int n, float tolerance) {
for (int i = 0; i < n; i++) {
if (fabsf(a[i] - b[i]) > tolerance) {
printf("Mismatch at index %d: %f != %f\n", i, a[i], b[i]);
return false;
}
}
return true;
}

int main(int argc, const char ** argv) {
struct ggml_init_params params = {
.mem_size = 128*1024*1024,
.mem_buffer = NULL,
.no_alloc = false,
};

// initialize the backend
struct ggml_context * ctx = ggml_init(params);

// Create test signal
float input_signal[N_SAMPLES];
float output_signal[N_SAMPLES];
generate_test_signal(input_signal, N_SAMPLES);

// Create tensors
struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N_SAMPLES);
struct ggml_tensor * fft_result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2 * N_SAMPLES);
struct ggml_tensor * ifft_result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N_SAMPLES);

// Copy input signal to tensor
memcpy(input->data, input_signal, N_SAMPLES * sizeof(float));

// Create compute graph
struct ggml_cgraph * gf = ggml_new_graph(ctx);
struct ggml_cgraph * gb = ggml_new_graph(ctx);

// Perform FFT
fft_result = ggml_fft(ctx, input);
ggml_build_forward_expand(gf, fft_result);

// Perform IFFT
ifft_result = ggml_ifft(ctx, fft_result);
ggml_build_forward_expand(gb, ifft_result);

// Compute the graphs
ggml_graph_compute_with_ctx(ctx, gf, 1);
ggml_graph_compute_with_ctx(ctx, gb, 1);

// Copy result back
memcpy(output_signal, ifft_result->data, N_SAMPLES * sizeof(float));

// Compare input and output
const float tolerance = 1e-5f;
bool success = compare_arrays(input_signal, output_signal, N_SAMPLES, tolerance);

if (success) {
printf("FFT/IFFT test passed! Signal was correctly reconstructed within tolerance %f\n", tolerance);
} else {
printf("FFT/IFFT test failed! Signal reconstruction error exceeded tolerance %f\n", tolerance);

// Print signals for comparison
printf("\nOriginal signal:\n");
for (int i = 0; i < N_SAMPLES; i++) {
printf("%f ", input_signal[i]);
}
printf("\n\nReconstructed signal:\n");
for (int i = 0; i < N_SAMPLES; i++) {
printf("%f ", output_signal[i]);
}
printf("\n");
}

ggml_free(ctx);

return success ? 0 : 1;
}