
[Performance] mlx.core.conv_general is really slow #1409

Open
valfrom opened this issue Sep 13, 2024 · 4 comments

valfrom commented Sep 13, 2024

Describe the bug
The method mlx.core.conv_general is significantly slower than its PyTorch analog, anywhere from 10x to 150x slower.

To Reproduce
Just run the attached code.

Include code snippet

import mlx.core as mx
import time
import torch


def mlx_sample():
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)

    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')


def torch_sample():
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        out.max()

    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')


def main():
    mlx_sample()
    torch_sample()


if __name__ == '__main__':
    main()

Output:
MLX time: 20.65ms
MPS time: 0.93ms

Expected behavior
At least the same speed as PyTorch.

Desktop (please complete the following information):

  • OS Version: [e.g. MacOS 14.6.1]
  • Version [e.g. 0.17.3]
  • Tested on M2 Ultra and M1 Pro Max
@valfrom valfrom changed the title [Performance] [Performance] mlx.core.conv_general is really slow Sep 13, 2024
awni (Member) commented Sep 13, 2024

Indeed, we are aware that there are performance cliffs in our convolutions; see e.g. #1313.

Thanks for the benchmark though! We will make sure to include the 3D stuff in our optimizations.

FYI, there are a few problems with your benchmark:

  • Make sure you eval the inputs before running the function you want to benchmark
  • Always use a warmup (this also ensures the above will be true)
  • Make sure you synchronize on the MPS device as well

Here is an improved version:

def mlx_sample():
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup
    for _ in range(5):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)

    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)
    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')


def torch_sample():
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]
    # Warmup
    for _ in range(5):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        torch.mps.synchronize()

    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
        torch.mps.synchronize()

    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')

Finally, for a 1x1x1 convolution, I'd encourage you to use a Linear layer / matmul. That will be way faster for now. I ran the revised benchmark on an M2 Ultra:

MLX time: 5.26ms
MPS time: 1.38ms

MLX with matmul instead:

MLX time: 0.90ms

That looks like this:

    x = mx.random.normal([8, 16 * 128 * 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 32], dtype=mx.float32)
    out = x @ weight.T
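For reference, the identity behind this trick (a 1x1x1 convolution over channels-last input is exactly one matmul over the channel axis) can be checked with a small runnable sketch; numpy stands in for mlx here so it runs anywhere, and the shapes are scaled down from the benchmark:

```python
import numpy as np

# Small-scale check that a 1x1x1 convolution over channels-last input
# is a single matmul over the channel axis.
rng = np.random.default_rng(0)
N, D, H, W, Cin, Cout = 2, 3, 4, 4, 5, 6
x = rng.standard_normal((N, D, H, W, Cin))
w = rng.standard_normal((Cout, 1, 1, 1, Cin))

# Direct 1x1x1 "convolution": every output voxel is a linear map of the
# input channels at that same voxel.
conv = np.einsum('ndhwc,oc->ndhwo', x, w.reshape(Cout, Cin))

# The same computation as one matmul over flattened spatial positions.
mm = (x.reshape(-1, Cin) @ w.reshape(Cout, Cin).T).reshape(N, D, H, W, Cout)

assert np.allclose(conv, mm)
```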

valfrom (Author) commented Sep 13, 2024

Thanks a lot

valfrom (Author) commented Sep 15, 2024

After some digging, the main issue seems to be the channel size of the input and the first dimension of the weight. PyTorch implements convolution in native code using MPSGraphConvolution3DOpDescriptor.
P.S. A convolution with the following parameters can't be computed at all; it fails with the error "libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Internal Error (0000000e:Internal Error)". I guess it's a kernel timeout or something similar.

    x = mx.random.normal([8, 142, 16, 64, 64], dtype=mx.float32)
    weight = mx.random.normal([22, 142, 7, 7, 7], dtype=mx.float32)
    bias = mx.random.normal([22], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [3, 3, 3]
    dilation = [1, 1, 1]
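For scale, the cost of that configuration can be estimated with a back-of-the-envelope sketch (an editorial illustration assuming a direct 3D convolution, not part of the original report):

```python
# Rough cost of the failing configuration, assuming a direct
# (non-Winograd, non-implicit-GEMM) 3D convolution.
N, Cin, D, H, W = 8, 142, 16, 64, 64
Cout, K, pad = 22, 7, 3

# Output spatial extent for stride 1, dilation 1.
out_d = D + 2 * pad - K + 1
out_h = H + 2 * pad - K + 1
out_w = W + 2 * pad - K + 1
outputs = N * Cout * out_d * out_h * out_w

# Multiply-accumulates: each output element reduces over Cin * K^3 taps.
macs = outputs * Cin * K ** 3

print(out_d, out_h, out_w)        # 16 64 64
print(f'{macs / 1e9:.1f} GMACs')  # 561.8 GMACs
```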

Is there a way to get a Metal buffer from an array? Then I'll probably be able to use the native function to calculate the 3D convolution.

awni (Member) commented Sep 16, 2024

You have a couple of options:

  • use the Python buffer protocol memoryview(a)
  • get the DLPack capsule a.__dlpack__()

More info in the docs on converting to other frameworks.
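Both options can be sketched with numpy standing in for an mlx array, since both array types implement the buffer protocol and DLPack (an illustrative sketch, not mlx-specific code):

```python
import numpy as np

# numpy stands in for an mlx array; both support the same two
# zero-copy export mechanisms mentioned above.
a = np.arange(6, dtype=np.float32)

# Option 1: the Python buffer protocol gives a zero-copy view of the
# underlying bytes.
view = memoryview(a)
print(view.format, view.nbytes)  # f 24

# Option 2: the DLPack protocol; a consumer framework turns the
# __dlpack__() capsule back into its own zero-copy array type.
b = np.from_dlpack(a)
assert np.shares_memory(a, b)
```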
