Commit
further refine super tile size; implement a proper performance comparison
bssrdf committed Dec 18, 2024
1 parent 4f93d67 commit 1d606ed
Showing 3 changed files with 165 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/ggml-cuda/conv-winograd.cu
@@ -843,7 +843,7 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor *
cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int));
cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int));

conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8,
conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 16, 2,
tile_size, tile_2d_s,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
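For context, the third and fourth arguments to conv_winograd_stage1_f32_f32_cuda appear to describe the shape of the per-block "super tile" named in the commit message (an assumption, based on the paired TW/TH change below): the block still covers the same number of Winograd tiles, it just becomes wider and shorter. A quick arithmetic check of the numbers in this hunk:

// Illustrative check only: old shape 4x8 and new shape 16x2 cover the same
// number of tiles per block; only the aspect ratio of the super tile changes.
static_assert(4 * 8 == 16 * 2, "super tile still covers 32 tiles per block");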
4 changes: 2 additions & 2 deletions src/ggml-cuda/conv-winograd.cuh
@@ -4,8 +4,8 @@
#define BC 8
#define BN 32
#define BK 64
#define TW 8
#define TH 16
#define TW 32
#define TH 4
#define BN_p 138

__constant__ int access_f_s[2][32];
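The TW/TH pair follows the same reshaping: the per-block output footprint stays at 128 elements, but it goes from a tall 8x16 region to a wide 32x4 one, which plausibly favors coalesced accesses along the fastest-varying output dimension (an assumption, not stated in the commit). A minimal consistency sketch; the 2x2 factor assumes each Winograd tile contributes a 2x2 output block:

// Illustrative only: TW*TH is unchanged by this commit, and the new 32x4 shape
// is consistent with the 16x2 tiles-per-block launch above if every tile
// produces a 2x2 output patch (assumption).
static_assert(8 * 16 == 32 * 4, "TW*TH footprint stays at 128 outputs per block");
static_assert(16 * 2 == 32 && 2 * 2 == 4, "16x2 tiles of 2x2 outputs give a 32x4 super tile");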
187 changes: 162 additions & 25 deletions tests/test-conv2d-winograd.cpp
@@ -4,6 +4,7 @@

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
//#include <cuda_runtime.h>
#endif

#ifdef GGML_USE_METAL
@@ -36,8 +37,10 @@ struct test_model {

void load_model(test_model & model, bool use_gpu = false) {
// create data
int KW = 3, KH = 3, IC = 32, OC = 64;
int IW = 28, IH = 40, N = 1;
int KW = 3, KH = 3, IC = 256, OC = 256;
int IW = 832, IH = 1216, N = 1;

printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH);

// Initialize adata
std::vector<float> adata(KW * KH * IC * OC);
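For scale, the new benchmark dimensions make the activations by far the dominant allocation, which is why the compute-buffer size printed later becomes a meaningful number. A back-of-the-envelope estimate for the F32 tensors created above (illustrative arithmetic only, not part of the test):

// Rough footprint of the benchmark tensors, using the sizes set above.
const size_t input_bytes  = (size_t) IW * IH * IC * N  * sizeof(float); // 832*1216*256*4 bytes, roughly 1.0 GB
const size_t filter_bytes = (size_t) KW * KH * IC * OC * sizeof(float); // 3*3*256*256*4 bytes, roughly 2.4 MB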
@@ -135,7 +138,7 @@ void load_model(test_model & model, bool use_gpu = false) {
}
}

struct ggml_cgraph * build_graph(const test_model& model) {
struct ggml_cgraph * build_graph_0(const test_model& model) {
static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);

@@ -163,21 +166,104 @@ struct ggml_cgraph * build_graph(const test_model& model) {
struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
ggml_set_name(conv2d_res, "conv2d_res");
ggml_build_forward_expand(gf, conv2d_res);
int64_t *ne = conv2d_res->ne;
printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
// int64_t *ne = conv2d_res->ne;
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);


// struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b);
// ggml_set_name(wino_res, "wino_res");
// ggml_build_forward_expand(gf, wino_res);
// ne = wino_res->ne;
// printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
ggml_free(ctx0);
return gf;
}

struct ggml_cgraph * build_graph_1(const test_model& model) {
static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);

struct ggml_init_params params0 = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf.data(),
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};

// create a temporary context to build the graph
struct ggml_context * ctx0 = ggml_init(params0);

struct ggml_cgraph * gf = ggml_new_graph(ctx0);

int s0 = 1;
int s1 = 1;
int p0 = 1;
int p1 = 1;
int d0 = 1;
int d1 = 1;



// recalculate to avoid fragmentation
// struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
// ggml_set_name(conv2d_res, "conv2d_res");
// ggml_build_forward_expand(gf, conv2d_res);
// int64_t *ne = conv2d_res->ne;
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);


struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b);
ggml_set_name(wino_res, "wino_res");
ggml_build_forward_expand(gf, wino_res);
ne = wino_res->ne;
printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
// ne = wino_res->ne;
// printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
ggml_free(ctx0);
return gf;
}

struct ggml_cgraph * compute_graph(const test_model & model, ggml_gallocr_t allocr) {
struct ggml_cgraph * gf = build_graph(model);
struct ggml_cgraph * compute_graph_0(const test_model & model, ggml_gallocr_t allocr) {
struct ggml_cgraph * gf = build_graph_0(model);

// allocate tensors
ggml_gallocr_alloc_graph(allocr, gf);
int n_threads = 1;

if (ggml_backend_is_cpu(model.backend)) {
ggml_backend_cpu_set_n_threads(model.backend, n_threads);
}

#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(model.backend)) {
ggml_backend_metal_set_n_cb(model.backend, n_threads);
}
#endif

int iterations = 20;

ggml_backend_graph_compute(model.backend, gf);

ggml_backend_synchronize(model.backend);

int64_t start_time = ggml_time_us();

for(int iter=0; iter<iterations; iter++){
ggml_backend_graph_compute(model.backend, gf);
}

ggml_backend_synchronize(model.backend);
int64_t end_time = ggml_time_us();
double time_us = end_time - start_time;

time_us = time_us/iterations;
printf(" Taking %f ms\n ", time_us/1000);

//ggml_graph_print(gf);

return gf;
}
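The block above is the "proper performance comparison" from the commit message: one warm-up evaluation, then an averaged timing loop bracketed by backend synchronizations so queued GPU work is actually finished before the clock stops. As a sketch, the same pattern could be factored into a small helper (a hypothetical function, not part of this commit; it only reuses the ggml calls already used above):

static double benchmark_graph_ms(ggml_backend_t backend, struct ggml_cgraph * gf, int iterations) {
    // warm-up run so lazy initialization and caching do not pollute the timing
    ggml_backend_graph_compute(backend, gf);
    ggml_backend_synchronize(backend);

    int64_t start_time = ggml_time_us();
    for (int iter = 0; iter < iterations; iter++) {
        ggml_backend_graph_compute(backend, gf);
    }
    // wait for all queued work to finish before stopping the clock
    ggml_backend_synchronize(backend);
    int64_t end_time = ggml_time_us();

    // average time per graph evaluation, in milliseconds
    return (double)(end_time - start_time) / iterations / 1000.0;
}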


struct ggml_cgraph * compute_graph_1(const test_model & model, ggml_gallocr_t allocr) {
struct ggml_cgraph * gf = build_graph_1(model);

// allocate tensors
ggml_gallocr_alloc_graph(allocr, gf);
@@ -193,8 +279,25 @@ struct ggml_cgraph * compute_graph(const test_model & model, ggml_gallocr_t allo
}
#endif

int iterations = 20;

ggml_backend_graph_compute(model.backend, gf);

ggml_backend_synchronize(model.backend);

int64_t start_time = ggml_time_us();

for(int iter=0; iter<iterations; iter++){
ggml_backend_graph_compute(model.backend, gf);
}

ggml_backend_synchronize(model.backend);
int64_t end_time = ggml_time_us();
double time_us = end_time - start_time;

time_us = time_us/iterations;
printf(" Taking %f ms\n ", time_us/1000);

//ggml_graph_print(gf);

return gf;
@@ -213,24 +316,58 @@ int main(void)
allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));

//create the worst case graph for memory usage estimation
struct ggml_cgraph * gf = build_graph(model);
struct ggml_cgraph * gf = build_graph_0(model);

// compute the required memory
ggml_gallocr_reserve(allocr, gf);
size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);
fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
}

struct ggml_cgraph * gf_res_0 = NULL;

gf_res_0 = compute_graph_0(model, allocr);


// ggml_gallocr_t allocr = NULL;

{
allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));

//create the worst case graph for memory usage estimation
struct ggml_cgraph * gf = build_graph_1(model);

// compute the required memory
ggml_gallocr_reserve(allocr, gf);
size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);
fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
}

struct ggml_cgraph * gf_res = compute_graph(model, allocr);
struct ggml_cgraph * gf_res_1 = NULL;

gf_res_1 = compute_graph_1(model, allocr);






struct ggml_tensor * wino_res = NULL;
struct ggml_tensor * conv2d_res = NULL;

for(int i = 0; i < ggml_graph_n_nodes(gf_res); ++i) {
if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "wino_res") == 0) {
wino_res = ggml_graph_node(gf_res, i);
} else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "conv2d_res") == 0) {
conv2d_res = ggml_graph_node(gf_res, i);
for(int i = 0; i < ggml_graph_n_nodes(gf_res_0); ++i) {
if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "wino_res") == 0) {
wino_res = ggml_graph_node(gf_res_0, i);
} else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_0, i)), "conv2d_res") == 0) {
conv2d_res = ggml_graph_node(gf_res_0, i);
}
}

for(int i = 0; i < ggml_graph_n_nodes(gf_res_1); ++i) {
if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "wino_res") == 0) {
wino_res = ggml_graph_node(gf_res_1, i);
} else if(strcmp(ggml_get_name(ggml_graph_node(gf_res_1, i)), "conv2d_res") == 0) {
conv2d_res = ggml_graph_node(gf_res_1, i);
}
}

Expand All @@ -245,15 +382,15 @@ int main(void)

bool passed = true;
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
for(int i = 0; i < 3*28; i++) {
float diff = fabs(conv2d_data[i] - wino_data[i]);
// if(diff > 1.e-4) {
printf("(%f, %f, %f, %d) \n",
conv2d_data[i],
wino_data[i], diff, i);
// break;
// }
}
// for(int i = 0; i < 3*28; i++) {
// float diff = fabs(conv2d_data[i] - wino_data[i]);
// // if(diff > 1.e-4) {
// printf("(%f, %f, %f, %d) \n",
// conv2d_data[i],
// wino_data[i], diff, i);
// // break;
// // }
// }
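With the verification loop disabled, the test now only reports timings. A minimal sketch of restoring the element-wise check as a pass/fail result, reusing the 1e-4 tolerance from the commented-out loop (assumption: conv2d_data and wino_data both hold valid results in the same layout, as in the surrounding code):

    // Illustrative re-enablement of the disabled comparison above.
    for (int i = 0; i < ggml_nelements(wino_res); i++) {
        float diff = fabs(conv2d_data[i] - wino_data[i]);
        if (diff > 1.e-4) {
            printf("mismatch at %d: conv2d=%f, wino=%f, diff=%f\n",
                   i, conv2d_data[i], wino_data[i], diff);
            passed = false;
            break;
        }
    }
    printf("winograd vs direct conv2d: %s\n", passed ? "OK" : "FAILED");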



