Skip to content

Commit f73bc64

Browse files
Added pytorch frontend corrcoef and its tests (ivy-llc#20953)
Signed-off-by: vishwajit <[email protected]> Co-authored-by: Tugay Gül <[email protected]>
1 parent f069324 commit f73bc64

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

ivy/functional/frontends/torch/miscellaneous_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,13 @@ def view_as_real(input):
426426
re_part = ivy.real(input)
427427
im_part = ivy.imag(input)
428428
return ivy.stack((re_part, im_part), axis=-1)
429+
430+
431+
@to_ivy_arrays_and_back
432+
def corrcoef(input):
433+
if len(ivy.shape(input)) > 2:
434+
raise ivy.exceptions.IvyError(
435+
"corrcoef(): expected input to have two or fewer dimensions but got an"
436+
f" input with {ivy.shape(input)} dimansions"
437+
)
438+
return ivy.corrcoef(input, y=None, rowvar=True)

ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,3 +1584,37 @@ def test_torch_view_as_real(
15841584
on_device=on_device,
15851585
input=np.asarray(x[0], dtype=input_dtype[0]),
15861586
)
1587+
1588+
1589+
# corrcoef
1590+
@handle_frontend_test(
1591+
fn_tree="torch.corrcoef",
1592+
dtypes_and_x=helpers.dtype_and_values(
1593+
available_dtypes=helpers.get_dtypes("float"),
1594+
num_arrays=1,
1595+
min_num_dims=2,
1596+
max_num_dims=2,
1597+
min_dim_size=2,
1598+
max_dim_size=2,
1599+
min_value=1,
1600+
),
1601+
test_with_out=st.just(False),
1602+
)
1603+
def test_torch_corrcoef(
1604+
dtypes_and_x,
1605+
frontend,
1606+
fn_tree,
1607+
on_device,
1608+
test_flags,
1609+
backend_fw,
1610+
):
1611+
input_dtypes, x = dtypes_and_x
1612+
helpers.test_frontend_function(
1613+
input_dtypes=["float64"],
1614+
frontend=frontend,
1615+
fn_tree=fn_tree,
1616+
test_flags=test_flags,
1617+
on_device=on_device,
1618+
backend_to_test=backend_fw,
1619+
input=x[0],
1620+
)

0 commit comments

Comments
 (0)