Skip to content

Conversation

hsjts0u
Copy link

@hsjts0u hsjts0u commented Sep 4, 2025

Add default args for _aten_conv2d, which would otherwise fail in the following code snippet

import torch
from torch.export import export_for_training
import torchax
from torchax import interop
from torch.utils import _pytree as pytree
import jax
from torchax.ops import mappings

class Simple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        return x
    

model = Simple()

exported = export_for_training(model, (torch.randn(1, 3, 224, 224),))

def make_shape_struct(x):
    return jax.ShapeDtypeStruct(x.shape, mappings.t2j_dtype(x.dtype))


def map_nth(v, i):
    def f(t):
        if isinstance(t, torch.Tensor):
            return t[i : i + 1]
        return t

    return pytree.tree_map(f, v)


env = torchax.default_env()
with env:
    model = exported.module().to("jax")

    def func_to_export(x):
        # hard code weights in model
        return model(x)

    example_inputs_jax = pytree.tree_map_only(
        torch.Tensor, lambda x: x.to("jax"), map_nth(exported.example_inputs, 0)
    )

    res = jax.jit(interop.jax_view(func_to_export)).lower(*example_inputs_jax[0])

# TypeError: _aten_conv2d() missing 5 required positional arguments: 'bias', 'stride', 'padding', 'dilation', and 'groups'

cc @qihqi

@hsjts0u hsjts0u changed the title Add default args for _aten_con2d Add default args for _aten_conv2d Sep 4, 2025
@qihqi qihqi enabled auto-merge (squash) September 18, 2025 00:23
@qihqi
Copy link
Collaborator

qihqi commented Sep 18, 2025

thanks!

@qihqi
Copy link
Collaborator

qihqi commented Sep 19, 2025

Hi @hsjts0u would you rebase to latest HEAD? it should fix the CI issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants