MLX Backend #19571
PyTest output:
=========================================================================== test session starts ============================================================================
platform darwin -- Python 3.12.2, pytest-8.1.1, pluggy-1.4.0 -- /Users/kartheek/erlang-ws/github-ws/latest/keras/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /Users/kartheek/erlang-ws/github-ws/latest/keras
configfile: pyproject.toml
plugins: cov-5.0.0
collected 6 items
keras/src/ops/operation_test.py::OperationTest::test_autoconfig PASSED [ 16%]
keras/src/ops/operation_test.py::OperationTest::test_eager_call PASSED [ 33%]
keras/src/ops/operation_test.py::OperationTest::test_input_conversion FAILED [ 50%]
keras/src/ops/operation_test.py::OperationTest::test_serialization PASSED [ 66%]
keras/src/ops/operation_test.py::OperationTest::test_symbolic_call PASSED [ 83%]
keras/src/ops/operation_test.py::OperationTest::test_valid_naming PASSED [100%]
================================================================================= FAILURES =================================================================================
___________________________________________________________________ OperationTest.test_input_conversion ____________________________________________________________________
self = <keras.src.ops.operation_test.OperationTest testMethod=test_input_conversion>
def test_input_conversion(self):
x = np.ones((2,))
y = np.ones((2,))
z = knp.ones((2,)) # mix
if backend.backend() == "torch":
z = z.cpu()
op = OpWithMultipleInputs()
> out = op(x, y, z)
keras/src/ops/operation_test.py:152:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/src/utils/traceback_utils.py:113: in error_handler
return fn(*args, **kwargs)
keras/src/ops/operation.py:56: in __call__
return self.call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Operation name=op_with_multiple_inputs>, x = array([1., 1.]), y = array([1., 1.])
z = <[ValueError('item can only be called on arrays of size 1.') raised in repr()] array object at 0x13f7450c0>
def call(self, x, y, z=None):
# `z` has to be put first due to the order of operations issue with
# torch backend.
> return 3 * z + x + 2 * y
E ValueError: Cannot perform addition on an mlx.core.array and ndarray
keras/src/ops/operation_test.py:14: ValueError
========================================================================= short test summary info ==========================================================================
FAILED keras/src/ops/operation_test.py::OperationTest::test_input_conversion - ValueError: Cannot perform addition on an mlx.core.array and ndarray
======================================================================= 1 failed, 5 passed in 0.13s ========================================================================
Any ideas on how to fix this test case? add(mx_array, numpy_array) works, but the + operator fails. Should we skip this test for the MLX backend?
It's not fixable on our side; we should file an issue with the MLX repo.
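Why add(mx_array, numpy_array) can succeed while + fails comes down to Python's binary-operator protocol: `a + b` first calls `type(a).__add__(a, b)`, and only if that returns `NotImplemented` does Python try `type(b).__radd__(b, a)`. If the left operand's `__add__` raises instead (which the `ValueError` in the log above suggests the MLX array does when handed an `ndarray`), the right operand never gets a chance. The classes below are hypothetical toy stand-ins, not MLX or NumPy types; they only model the dispatch, under that assumption.

```python
# Hypothetical toy classes (not MLX or NumPy) that model how Python
# dispatches `a + b` between two unrelated array types.


class StrictArray:
    """Left operand whose __add__ raises on foreign types."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, StrictArray):
            return StrictArray(self.value + other.value)
        # Raising (instead of returning NotImplemented) ends dispatch here,
        # so Python never calls other.__radd__.
        raise ValueError(
            f"Cannot perform addition on a StrictArray and {type(other).__name__}"
        )


class PoliteArray:
    """Left operand whose __add__ defers on foreign types."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, PoliteArray):
            return PoliteArray(self.value + other.value)
        # NotImplemented tells Python to try the reflected other.__radd__.
        return NotImplemented


class OtherArray:
    """Right operand that can absorb any object with a .value via __radd__."""

    def __init__(self, value):
        self.value = value

    def __radd__(self, other):
        return OtherArray(other.value + self.value)


def add(a, b):
    """Free function that unwraps both operands itself, so no dunder
    dispatch between the two array types is involved."""
    return a.value + b.value
```

With these, PoliteArray(1) + OtherArray(2) falls through to OtherArray.__radd__, StrictArray(1) + OtherArray(2) raises ValueError, and add(StrictArray(1), OtherArray(2)) returns 3. Since the failing step is inside the left operand's `__add__`, this is consistent with the view that the change belongs upstream in MLX.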
Thank you for the list. I am working on conv in keras/backend/mlx/nn.py.
I am working on segment_sum, segment_max, max_pool and avg_pool. Thank you.
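For anyone checking segment_sum and segment_max on MLX against the other backends, a plain-Python 1-D reference of the semantics can be handy for producing hand-checked test values. This is only a sketch under stated assumptions: 1-D `data`, an explicit `num_segments`, out-of-range ids silently dropped, and empty segments filled with 0 (sum) or a caller-chosen `fill` (max, default `-inf`). Those choices should be confirmed against the Keras `segment_sum`/`segment_max` docstrings before relying on them.

```python
from typing import List, Sequence


def segment_sum(
    data: Sequence[float], segment_ids: Sequence[int], num_segments: int
) -> List[float]:
    """Sum entries of 1-D `data` that share a segment id; empty segments are 0."""
    out = [0.0] * num_segments
    for value, seg in zip(data, segment_ids):
        # Out-of-range ids are silently dropped (an assumption of this sketch).
        if 0 <= seg < num_segments:
            out[seg] += value
    return out


def segment_max(
    data: Sequence[float],
    segment_ids: Sequence[int],
    num_segments: int,
    fill: float = float("-inf"),
) -> List[float]:
    """Max of 1-D `data` per segment id; empty segments keep `fill`."""
    out = [fill] * num_segments
    for value, seg in zip(data, segment_ids):
        if 0 <= seg < num_segments:
            out[seg] = max(out[seg], value)
    return out
```

For example, segment_sum([1, 2, 3, 4], [0, 0, 1, 1], 3) gives [3.0, 7.0, 0.0], and segment_max([1, 5, 3, 4], [0, 0, 1, 1], 3) gives [5, 4, -inf]. The loop does not assume sorted ids, so the same reference covers the unsorted case.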
I want to take a stab at
Thank you @yrahul3910, please go ahead with adding
FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_tensordot_('int16', 'bool') - ValueError: [matmul] Only real floating point types are supported but int16 and bool were provided which results in int16, which is not a real floating point type.
@fchollet How should we handle this? One option is to cast integer arguments to float32 when both inputs are integers, so the result would be float32. If we go this route, we have to modify test cases in
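The casting rule proposed above can be written down as a small dtype decision that the MLX matmul/tensordot wrappers would consult before calling into MLX. This is a sketch only: the helper name matmul_cast_dtype is hypothetical, dtypes are represented as plain strings, and the set of "real floating point" dtypes is an assumption to be checked against the MLX version in use.

```python
from typing import Optional

# Dtypes treated as "real floating point" for MLX matmul in this sketch;
# the exact set is an assumption, check it against the MLX version in use.
REAL_FLOAT_DTYPES = frozenset({"float16", "bfloat16", "float32"})


def matmul_cast_dtype(dtype_a: str, dtype_b: str) -> Optional[str]:
    """Return the dtype to cast both operands to before matmul/tensordot,
    or None to pass them through unchanged.

    Follows the proposal in this thread: when neither operand is a real
    floating type (e.g. int16 with bool), cast both to float32 so the
    result is float32.
    """
    if dtype_a in REAL_FLOAT_DTYPES or dtype_b in REAL_FLOAT_DTYPES:
        # At least one float input: leave promotion to the backend.
        return None
    return "float32"
```

With this rule, matmul_cast_dtype("int16", "bool") returns "float32", while matmul_cast_dtype("float16", "int8") returns None. The trade-off is that integer matmuls no longer return integer dtypes on MLX, which is why the dtype test expectations would have to change for this backend.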
Just wanted to let you all know about some updates to MLX as of 0.16.1 that may be useful here:
Are there any high-priority items we can fix or add to help move this along?
Thank you @awni, we need some help moving this forward. I will make a list and get back to you in a day or two.
Issue for tracking and coordinating MLX backend work:

mlx.math
- fft
- fft2
- rfft
- irfft
- stft
- istft
- logsumexp (mlx - add missing convert_to_tensor #19578)
- qr
- segment_sum (mlx - implement segment_sum and segment_max #19652)
- segment_max (mlx - implement segment_sum and segment_max #19652)
- erfinv (feat(math): support erfinv on mlx #19628)

mlx.numpy
- einsum
- bincount
- nonzero
- cross
- vdot
- nan_to_num
- copy
- roll
- median (Implemented median(...) function. #19568; Implement missing functions in mlx backend #19574)
- meshgrid (Implement missing functions in mlx backend #19574)
- conjugate
- arctan2 (Added arctan2 operation #19759)
- quantile
- imag
- real
- select
- argpartition (mlx - add argpartition to numpy #19680)
- slogdet
- select
- vectorize
- correlate
- diag (mlx - fix diag and diagonal in numpy #19714)
- diagonal (mlx - fix diag and diagonal in numpy #19714)

mlx.image
- rgb_to_grayscale (mlx - add rgb_to_grayscale #19609)
- resize (mlx - image.resize add crop_to_aspect_ratio argument #19699)

mlx.nn
- max_pool
- avg_pool
- conv
- depthwise_conv
- separable_conv
- conv_transpose
- ctc_loss

mlx.rnn
- rnn
- lstm
- gru

mlx.linalg
- cholesky
- det
- eig
- eigh
- inv
- lu_factor
- norm (mlx - add linalg.norm #19698)
- qr
- solve
- solve_triangular
- svd

mlx.core
- np.ndarray of bfloat16 using ml_dtypes is being interpreted as complex64 (ml-explore/mlx#1075)
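Until ml-explore/mlx#1075 is resolved, one possible workaround (an assumption, not something confirmed in this thread) is to upcast bfloat16 NumPy arrays to float32 before handing them to MLX and cast back to bfloat16 on the MLX side. That upcast loses nothing, because a bfloat16 value is exactly the upper 16 bits of an IEEE-754 float32. The stdlib-only sketch below demonstrates that bit relationship; the helper names are hypothetical and no MLX or ml_dtypes API is used.

```python
import struct


def bf16_bits_to_float32(bits: int) -> float:
    """Widen a bfloat16 bit pattern to float32 by appending 16 zero bits."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]


def float32_to_bf16_bits(value: float) -> int:
    """Narrow a float32-representable value to bfloat16 by truncation,
    keeping the upper 16 bits. Real converters usually round to nearest
    even; truncation keeps this sketch short."""
    (word,) = struct.unpack(">I", struct.pack(">f", value))
    return word >> 16
```

For instance, 1.0 is 0x3F800000 as float32, so its bfloat16 pattern is 0x3F80, and widening 0x3F80 gives back exactly 1.0. Every non-NaN bfloat16 pattern survives the float32 round trip unchanged, which is what makes the upcast-first route lossless.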