Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions src/zgemm_NN_gpu.jdf
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern "C" %{
/*
* Copyright (c) 2017-2024 The University of Tennessee and The University
* Copyright (c) 2017-2025 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
*
Expand Down Expand Up @@ -113,6 +113,23 @@ static int pred_z(int x, int y, int z, int xMax, int yMax, int zMax, int l)
return z;
}

/* Define the different shapes this JDF is using.
 * These identifiers are passed as the `shape` argument of the ADTT_READ,
 * ADTT_DC and ADTT_CP macros in the dependency properties of this file:
 * A_SHAPE on flows carrying A, B_SHAPE on flows carrying B, and C_SHAPE on
 * flows carrying C. */
#define A_SHAPE 0
#define B_SHAPE 1
#define C_SHAPE 2

/* Assume the functions on type & type_remote will return parsec_arena_datatype_t */
#define JDF2C_TYPE_ADT_NOT_INDEX

/* Include the functions to obtain the parsec_arena_datatype_t.
 * NOTE(review): ADTT_DC and the LAPACK layout tag used below are not defined
 * in this file; presumably they come from this header -- confirm. */
#include "dplasmajdf_lapack_dtt.h"
/* ADTT_READ(dM, loc, shape, layout) is the datatype selector used in the
 * `type` and `type_data` properties of the READ flows that read directly from
 * descA, descB and descC.
 *   - With FULL_CONVERSION defined, the requested `layout` is forwarded
 *     unchanged to ADTT_DC.
 *   - Without it (the current setting: the define below is commented out),
 *     the `layout` argument is ignored and ADTT_DC is always called with
 *     LAPACK, so a `type` written with TILED and a `type_data` written with
 *     LAPACK expand to the same ADTT_DC(dM, loc, shape, LAPACK) call.
 * Uncomment the following line to enable the layout-forwarding variant. */
//#define FULL_CONVERSION
#ifdef FULL_CONVERSION
#define ADTT_READ(dM, loc, shape, layout) ADTT_DC(dM, loc, shape, layout)
#else
#define ADTT_READ(dM, loc, shape, layout) ADTT_DC(dM, loc, shape, LAPACK)
#endif

%}

/* Keep this first, as in all jdf in this directory, to
Expand All @@ -126,9 +143,12 @@ transB [ type = int ]
alpha [ type = dplasma_complex64_t ]
beta [ type = dplasma_complex64_t ]

descA [ type = "const parsec_tiled_matrix_t*" ]
descB [ type = "const parsec_tiled_matrix_t*" ]
descC [ type = "parsec_tiled_matrix_t*" ]
ddescA [type = "dplasma_data_collection_t*"]
descA [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescA)->dc_original" aligned=ddescA]
ddescB [type = "dplasma_data_collection_t*"]
descB [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescB)->dc_original" aligned=ddescB]
ddescC [type = "dplasma_data_collection_t*"]
descC [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescC)->dc_original" aligned=ddescC]

/*
* The process grid is tP x tQ
Expand All @@ -142,8 +162,8 @@ tQ [ type = int ]

LOOK_AHEAD [ type = int ]

nb_cuda_devices [ type = "int" ]
cuda_device_index [ type = "int *" ]
nb_gpu_devices [ type = "int" ]
gpu_device_index [ type = "int *" ]

xMax [ type = int default = "-1" hidden=on ]
yMax [ type = int default = "-1" hidden=on ]
Expand All @@ -163,12 +183,14 @@ READ_A(m, k, x, y, z)
z = k / tD .. k / tD
nmax = %{ int n1 = (y+1)*tC*tQ-1;
int n2 = descC->nt - 1;
return n1<n2 ? n1 : n2;
%}
return n1<n2 ? n1 : n2;
%}
loc_A = %{ return LOC(descA, m, k); %}

:descA(m, k)

READ A <- descA(m, k)
READ A <- descA(m, k) [ type = %{ return ADTT_READ(ddescA, loc_A, A_SHAPE, TILED); %}
type_data = %{ return ADTT_READ(ddescA, loc_A, A_SHAPE, LAPACK); %} ]
-> A GEMM(m, y*tC*tQ .. nmax, k)

CTL Y <- Y GLOBAL_BARRIER(x, y, z)
Expand All @@ -191,12 +213,14 @@ READ_B(k, n, x, y, z)
z = k / tD .. k / tD
mmax = %{ int m1 = (x+1)*tB*tP-1;
int m2 = descC->mt - 1;
return m1<m2 ? m1 : m2;
%}
return m1<m2 ? m1 : m2;
%}
loc_B = %{ return LOC(descB, k, n); %}

: descB(k, n)

READ B <- descB(k, n)
READ B <- descB(k, n) [ type = %{ return ADTT_READ(ddescB, loc_B, B_SHAPE, TILED); %}
type_data = %{ return ADTT_READ(ddescB, loc_B, B_SHAPE, LAPACK); %} ]
-> B GEMM(x*tB*tP .. mmax, n, k)

CTL Y <- Y GLOBAL_BARRIER(x, y, z)
Expand All @@ -216,19 +240,22 @@ READ_C(m, n)
u = r / tQ
v = r % tQ

loc_C = %{ return LOC(descC, m, n); %}

: descC(m, n)

READ C <- descC(m, n)
READ C <- descC(m, n) [ type = %{ return ADTT_READ(ddescC, loc_C, C_SHAPE, TILED); %}
type_data = %{ return ADTT_READ(ddescC, loc_C, C_SHAPE, LAPACK); %} ]
-> C GEMM(m, n, 0)

CTL Z <- Z LOCAL_BARRIER( m/(tB*tP), n/(tC*tQ), 0, u, v )

BODY
if( nb_cuda_devices > 0 ) {
int g = (n / tQ) % nb_cuda_devices;
if( nb_gpu_devices > 0 ) {
int g = (n / tQ) % nb_gpu_devices;
if( _f_C->original->preferred_device <= 0 ) {
parsec_advise_data_on_device( _f_C->original,
cuda_device_index[g],
gpu_device_index[g],
PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE );
}
}
Expand Down Expand Up @@ -352,14 +379,19 @@ GEMM(m, n, k)
yn = %{ return succ_y(x, y, z, xMax, yMax, zMax, 1); %}
zn = %{ return succ_z(x, y, z, xMax, yMax, zMax, 1); %}

loc_A = %{ return LOC(descA, m, k); %}
loc_B = %{ return LOC(descB, k, n); %}
loc_C = %{ return LOC(descC, m, n); %}

: descC(m, n)

READ A <- A READ_A(m, k, x, y, z)
READ B <- B READ_B(k, n, x, y, z)
READ A <- A READ_A(m, k, x, y, z) [ type_remote = %{ return ADTT_DC(ddescA, loc_A, A_SHAPE, TILED); %} ]
READ B <- B READ_B(k, n, x, y, z) [ type_remote = %{ return ADTT_DC(ddescB, loc_B, B_SHAPE, TILED); %} ]
RW C <- k == 0 ? C READ_C(m, n)
: C GEMM(m, n, k-1 )
-> k + 1 == descB->mt ? descC(m, n)
: C GEMM(m, n, k+1)
<- k != 0 ? C GEMM(m, n, k-1 )
-> k + 1 == descB->mt ? descC(m, n) [ type = %{ return ADTT_CP(_f_C, ddescC, loc_C, C_SHAPE); %}
type_data = %{ return ADTT_DC(ddescC, loc_C, C_SHAPE, LAPACK); %} ]
-> k + 1 != descB->mt ? C GEMM(m, n, k+1) /* dep OUT: rely on datacopy dtt for sending */

CTL Z <- ( k > 0 ) & ((k % tD) == 0) ? Z LOCAL_BARRIER(x, y, z, u, v)
-> ((k == descB->mt-1) | (k == (z+1)*tD-1)) ? Z LOCAL_BARRIER(xn, yn, zn, u, v)
Expand Down
Loading