Exporting MultiInputActorCriticPolicy as ONNX #1873

Open
4 tasks done
MaximCamilleri opened this issue Mar 18, 2024 · 10 comments · May be fixed by #2098
Labels
documentation Improvements or additions to documentation help wanted Help from contributors is welcomed question Further information is requested

Comments

@MaximCamilleri

❓ Question

Hi,

I am looking into using ONNX with SB3. I have tested two models (A2C and PPO) on a custom environment using a MultiInputActorCriticPolicy. The environment's observation space is of type Dict. So far I have not been able to produce an ONNX-exportable policy.

The documentation says "The following examples are for MlpPolicy only, and are general examples." Is it possible to export a model of my type to ONNX? If so, would it be possible to provide an example?

Thanks

Checklist

@MaximCamilleri MaximCamilleri added the question Further information is requested label Mar 18, 2024
@araffin araffin added the more information needed Please fill the issue template completely label Mar 18, 2024
@araffin
Member

araffin commented Mar 18, 2024

Hello,
what have you tried so far?
and what errors did you encounter?

Please provide a minimal and working code example (see link in issue template for what that means).

@MaximCamilleri
Author

Hello, thanks for your response.

I have tried a couple of things so far. First I tried converting my model into an onnxable policy using the method shown in the documentation. My code is as follows:

import numpy as np
import torch as th
from stable_baselines3 import PPO

class OnnxablePolicy(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, input):
        return self.policy(input)

model = PPO.load("Models/ppo.zip")
onnx_policy = OnnxablePolicy(model.policy)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

To get the dummy input, which I call obs_dict here, I used the following code snippet:

obs = env.reset()
obs_dict = {}
for key in obs.keys():
    obs_dict[key] = th.from_numpy(np.array([obs[key]])).float()

This creates an input with the same structure as the observation space after common.preprocessing.preprocess_obs is run.
The error I was getting at this point is: TypeError: OnnxablePolicy2.forward() missing 1 required positional argument: 'input'

I also tried the approach seen here, and created the following code:

class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, input):
        action_hidden = value_hidden = self.extractor(input)
        return self.action_net(action_hidden), self.value_net(value_hidden)

onnx_policy = OnnxablePolicy(model.policy.features_extractor, model.policy.action_net, model.policy.value_net)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

Which resulted in the same error as before.

Finally I tried using the policy as is:

model = PPO.load("Models/ppo.zip")
obs = env.reset()
th.onnx.export(
    model.policy,
    obs,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

This seemingly got me the furthest, producing the new error:

.venv/Lib/site-packages/stable_baselines3/common/preprocessing.py:
    110     assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
    111     preprocessed_obs = {}
    112     for key, _obs in obs.items():
AssertionError: Expected dict, got <class 'torch.Tensor'>

@araffin
Member

araffin commented Mar 25, 2024

I gave it a try, but this one seems to be a bit hard; you probably need to use the experimental ONNX export from PyTorch (using dynamo).
What got me further was passing (obs_dict, {}) as the observation; otherwise PyTorch tries to use the dict as keyword arguments.
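The empty-dict trick can be illustrated without PyTorch. Below is a minimal pure-Python sketch, assuming only the rule described above: if the last element of `args` is a dict, it is used as keyword arguments. `split_export_args` and `forward` are hypothetical stand-ins, not PyTorch functions.

```python
from typing import Any, Dict, Tuple


def split_export_args(args: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical model of the rule: a trailing dict in `args` becomes keyword arguments."""
    if not isinstance(args, tuple):
        # A single argument is treated as a 1-tuple
        args = (args,)
    if args and isinstance(args[-1], dict):
        return args[:-1], args[-1]
    return args, {}


def forward(input):
    # Stand-in for OnnxablePolicy.forward(self, input)
    return input


obs_dict = {"a": [0.1, 0.2, 0.3]}

# 1) Passing the dict alone: it is consumed as keyword arguments,
#    so forward() is called without its positional `input` and raises a TypeError
positional, keyword = split_export_args(obs_dict)
try:
    forward(*positional, **keyword)
except TypeError as err:
    print("TypeError:", err)

# 2) Appending an empty dict: the empty dict becomes the (empty) keyword
#    arguments, and obs_dict stays the first positional argument
positional, keyword = split_export_args((obs_dict, {}))
print(forward(*positional, **keyword))
```

This mirrors the convention documented for the TorchScript-based `torch.onnx.export`, where appending `{}` is the suggested way to pass a dict as the last positional argument.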

my current attempt (the export seems to work but the loading doesn't :/)

import torch as th
from typing import Tuple
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy

import onnx
import onnxruntime as ort
import numpy as np


class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # Debugging: print the input and return the raw "a" tensor;
        # the policy call below is currently unreachable
        print(observation)
        return observation["a"]
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, you also need to transpose the images so that they are channel first
        # use deterministic=False if you want to export the stochastic policy
        return self.policy._predict(observation, deterministic=True)


class Custom(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
                # "b": gym.spaces.Discrete(5),
            }
        )
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


env = Custom()
obs, _ = env.reset()
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MultiInputPolicy", env).save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")


onnx_policy = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
# Add batch dimension
dummy_input = {
    # "a": np.array(obs["a"])[np.newaxis, ...],
    "a": np.array(obs["a"]),
    # "b": np.array(obs["b"])[np.newaxis, ...],
}
dummy_input_tensor = {
    "a": th.as_tensor(dummy_input["a"]),
    # "b": th.as_tensor(dummy_input["b"]),
}

print(model.predict(dummy_input, deterministic=True))


th.onnx.export(
    onnx_policy,
    args=(dummy_input_tensor, {}),
    f="my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

##### Load and test with onnx


onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = dummy_input.copy()
ort_sess = ort.InferenceSession(onnx_path)

# print(ort_sess.get_inputs()[0].name)
# print(ort_sess.get_inputs())

output = ort_sess.run(None, {"input": observation})

print(output)

# Check that the predictions are the same
# with th.no_grad():
#     print(model.policy(th.as_tensor(observation), deterministic=True))
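Since the exported graph stops before post-processing, actions returned by ONNX Runtime may still need what `model.predict()` does for `Box` action spaces: unscale from [-1, 1] when the policy squashes its output, otherwise clip to the bounds (`Discrete` actions need neither). A minimal NumPy sketch, assuming you read `low`, `high` and the `squash_output` flag from your own model; the helper names are hypothetical:

```python
import numpy as np


def unscale_action(scaled: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Map actions from [-1, 1] back to [low, high] (for policies with squashed output)."""
    return low + 0.5 * (scaled + 1.0) * (high - low)


def postprocess_action(
    raw: np.ndarray, low: np.ndarray, high: np.ndarray, squash_output: bool
) -> np.ndarray:
    """Hypothetical post-processing step applied to the raw ONNX action output."""
    if squash_output:
        return unscale_action(raw, low, high)
    # Without squashing, the Gaussian mean/sample is only clipped to the bounds
    return np.clip(raw, low, high)


low = np.array([-2.0, 0.0], dtype=np.float32)
high = np.array([2.0, 10.0], dtype=np.float32)

raw = np.array([[0.5, -1.0], [3.0, 1.0]], dtype=np.float32)
print(postprocess_action(raw, low, high, squash_output=False))

squashed = np.array([[-1.0, 1.0], [0.0, 0.0]], dtype=np.float32)
print(postprocess_action(squashed, low, high, squash_output=True))
```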

@araffin
Member

araffin commented Mar 31, 2024

"
Due to design differences, input/output format between PyTorch model and exported ONNX model are often not the same. E.g., None is allowed for PyTorch model, but are not supported by ONNX. Nested constructs of tensors are allowed for PyTorch model, but only flattened tensors are supported by ONNX, etc."

from https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.ONNXProgram.adapt_torch_inputs_to_onnx
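A rough sketch of what "flattened" means for a dict observation: each leaf tensor becomes its own named input. The helper `flatten_inputs` and its `<parent>_<key>` naming are assumptions for illustration, chosen to resemble the input names reported later in this thread (e.g. `observation_a`); this is not PyTorch's actual naming code.

```python
from typing import Any, Dict


def flatten_inputs(nested: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: flatten nested dicts into one level, joining keys with '_'."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        name = f"{prefix}_{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into nested dicts, extending the name prefix
            flat.update(flatten_inputs(value, name))
        else:
            flat[name] = value
    return flat


model_kwargs = {"observation": {"a": [0.1, 0.2, 0.3], "b": [1.0] * 6}}
print(flatten_inputs(model_kwargs))
```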

@NickLucche
Contributor

Hi all, I wouldn't export the sampling procedure to ONNX here (self.policy._predict(observation, deterministic=True)), but would rather have the network output the raw logits and implement the sampling as a post-processing step.
A consistent export procedure would be a nice feature to add to the framework :)
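A minimal NumPy sketch of that post-processing step, assuming the exported graph ends at `action_net` and returns one row of logits per observation for a `Discrete` action space (the function names are hypothetical):

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def select_actions(logits: np.ndarray, deterministic: bool, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sampling step applied to the logits returned by ONNX Runtime."""
    if deterministic:
        # Mode of the categorical distribution
        return logits.argmax(axis=-1)
    probs = softmax(logits)
    # One categorical draw per batch row
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])


logits = np.array([[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]])
rng = np.random.default_rng(0)
print(select_actions(logits, deterministic=True, rng=rng))
print(select_actions(logits, deterministic=False, rng=rng))
```

Keeping the sampling outside the graph lets the same ONNX file serve both deterministic and stochastic use.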

@pulasthibandara

Just wondering if there has been any progress here? I've got the export to work, but when I try to predict, it requires a bunch of "_obs.17", "_obs.23", etc. inputs that the original model doesn't require.

@darkopetrovic

darkopetrovic commented Mar 2, 2025

Just started working with ONNX and was looking for a solution for the MultiInputPolicy.

After a lot of trial and error, I finally came up with a solution using th.onnx.dynamo_export.

  • To make it work, we first need to add a batch dimension to each key of the observation. In the example below, for the single environment, a simple .reshape(1, -1) does the trick. With the vectorized environment there is nothing to do, as the batch dimension corresponds to the number of environments.

  • The other point is that the keys of the ONNX Runtime input dictionary (onnxruntime_input in the example below) must match the input names created by the session.
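For the first point, a small NumPy helper could add the batch axis. This is a sketch with a hypothetical name; it uses `np.expand_dims` instead of `.reshape(1, -1)`, which gives the same result for 1-D `Box` observations but also keeps multi-dimensional ones (e.g. images) intact:

```python
from typing import Dict

import numpy as np


def add_batch_dim(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: give every key of a single observation a leading batch axis of size 1."""
    # expand_dims keeps the original shape and dtype of each entry
    return {key: np.expand_dims(np.asarray(value), axis=0) for key, value in obs.items()}


obs = {
    "a": np.zeros(3, dtype=np.float32),
    "b": np.zeros(6, dtype=np.float32),
    "img": np.zeros((2, 4, 4), dtype=np.uint8),
}
batched = add_batch_dim(obs)
for key, value in batched.items():
    print(key, value.shape, value.dtype)
```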

import numpy as np
import torch as th
import onnxruntime as ort
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.env_util import make_vec_env

USE_VECTORIZED_ENV = True


class OnnxableSB3Policy(th.nn.Module):
  def __init__(self, policy: BasePolicy):
    super().__init__()
    self.policy = policy

  def forward(self, observation):
    return self.policy(observation, deterministic=True)


class Custom(gym.Env):
  def __init__(self):
    super().__init__()
    self.observation_space = gym.spaces.Dict(
        {
            "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
            "b": gym.spaces.Box(low=-1, high=1, shape=(6,), dtype=np.float32),
            "c": gym.spaces.Box(low=-1, high=1, shape=(9,), dtype=np.float32),
        }
    )
    self.action_space = gym.spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)

  def reset(self, seed=None):
    return self.observation_space.sample(), {}

  def step(self, action):
    return self.observation_space.sample(), 0.0, False, False, {}


if not USE_VECTORIZED_ENV:
  # Single environment
  env = Custom()
  PPO("MultiInputPolicy", env).save("PathToTrainedModel")
  obs, _ = env.reset()
  # -> we need to add a batch dimension
  obs = {k: v.reshape(1, -1) for k, v in obs.items()}
else:
  # Vectorized environment
  vec_env = make_vec_env(Custom, n_envs=8)
  PPO("MultiInputPolicy", vec_env).save("PathToTrainedModel")
  # -> batch dimension already defined by the number of environments
  obs = vec_env.reset()

model = PPO.load("PathToTrainedModel.zip", device="cpu")

# Convert to ONNX
onnx_policy = OnnxableSB3Policy(model.policy)
obs_tensor = obs_as_tensor(obs, model.policy.device)
onnx_program = th.onnx.dynamo_export(onnx_policy, obs_tensor)
onnx_program.save("my-model.onnx")

# Load ONNX model and run
ort_session = ort.InferenceSession("my-model.onnx")
onnxruntime_input = {k.name: v for k, v in zip(ort_session.get_inputs(), obs.values())}
onnxruntime_outputs, _, _ = ort_session.run(None, onnxruntime_input)

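# Run PyTorch model and compare
# Note: model.predict() clips Box actions to the action-space bounds, whereas policy.forward()
# (what was exported) does not, so the outputs only match while actions stay within the bounds
torch_outputs, _ = model.predict(obs, deterministic=True)
th.testing.assert_close(torch_outputs, onnxruntime_outputs, rtol=1e-06, atol=1e-6)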

# Debug
print("Observations:")
for key, value in obs.items():
  print(f"{key}: {value.shape}")

print("\nONNX input:")
for key, value in onnxruntime_input.items():
  print(f"{key}: {value.shape}")

Output:

Observations:
a: (8, 3)
b: (8, 6)
c: (8, 9)

ONNX input:
_obs: (8, 3)
_obs_1: (8, 6)
_obs_2: (8, 9)

There are, however, some warnings in the output:

FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead. param_schemas = callee.param_schemas()

FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.

UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.

Package versions:

  • stable_baselines3: 2.4.0
  • torch: 2.2.2
  • onnx: 1.17.0
  • onnxruntime: 1.20.1
  • onnxscript: 0.2.1

@araffin
Member

araffin commented Mar 5, 2025

Hello,
thanks for the solution =)

Could you do a PR that adds a link in the documentation (section "exporting models") that redirects here?

@araffin araffin added documentation Improvements or additions to documentation help wanted Help from contributors is welcomed and removed more information needed Please fill the issue template completely labels Mar 5, 2025
@darkopetrovic

Updated solution, since dynamo_export will be deprecated:

from typing import Dict, Tuple

import numpy as np
import torch as th
import onnxruntime as ort
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.env_util import make_vec_env

USE_VECTORIZED_ENV = True


class OnnxableSB3Policy(th.nn.Module):
  def __init__(self, policy: BasePolicy):
    super().__init__()
    self.policy = policy

  def forward(self, observation: Dict[str, th.Tensor]) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    return self.policy(observation, deterministic=True)


class Custom(gym.Env):
  def __init__(self):
    super().__init__()
    self.observation_space = gym.spaces.Dict(
        {
            "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
            "b": gym.spaces.Box(low=-1, high=1, shape=(6,), dtype=np.float32),
            "c": gym.spaces.Box(low=-1, high=1, shape=(9,), dtype=np.float32),
        }
    )
    self.action_space = gym.spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)

  def reset(self, seed=None):
    return self.observation_space.sample(), {}

  def step(self, action):
    return self.observation_space.sample(), 0.0, False, False, {}


if not USE_VECTORIZED_ENV:
  # Single environment
  env = Custom()
  PPO("MultiInputPolicy", env).save("PathToTrainedModel")
  obs, _ = env.reset()
  # -> we need to add a batch dimension
  obs = {k: v.reshape(1, -1) for k, v in obs.items()}
else:
  # Vectorized environment
  vec_env = make_vec_env(Custom, n_envs=8)
  PPO("MultiInputPolicy", vec_env).save("PathToTrainedModel")
  # -> batch dimension already defined by the number of environments
  obs = vec_env.reset()

model = PPO.load("PathToTrainedModel.zip", device="cpu")

# Convert to ONNX
onnx_policy = OnnxableSB3Policy(model.policy)
obs_tensor = obs_as_tensor(obs, model.policy.device)

model_input = {
    "observation": obs_tensor
}

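# A trailing dict in `args` is interpreted as keyword arguments,
# so forward() is called with observation=obs_tensor
onnx_program = th.onnx.export(
    onnx_policy,
    args=(model_input,),
    f="my-model.onnx",
    dynamo=True
)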

# Load ONNX model and run
ort_session = ort.InferenceSession("my-model.onnx")
onnxruntime_input = {k.name: v for k, v in zip(ort_session.get_inputs(), obs.values())}
onnxruntime_outputs, _, _ = ort_session.run(None, onnxruntime_input)

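# Run PyTorch model and compare
# Note: model.predict() clips Box actions to the action-space bounds, whereas policy.forward()
# (what was exported) does not, so the outputs only match while actions stay within the bounds
torch_outputs, _ = model.predict(obs, deterministic=True)
th.testing.assert_close(torch_outputs, onnxruntime_outputs, rtol=1e-06, atol=1e-6)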

# Debug
print("Observations:")
for key, value in obs.items():
  print(f"{key}: {value.shape}")

print("\nONNX input:")
for key, value in onnxruntime_input.items():
  print(f"{key}: {value.shape}")

Output:

Observations:
a: (8, 3)
b: (8, 6)
c: (8, 9)

ONNX input:
observation_a: (8, 3)
observation_b: (8, 6)
observation_c: (8, 9)
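Building `onnxruntime_input` with `zip(...)` relies on `obs.values()` coming in the same order as `ort_session.get_inputs()`. With the `observation_<key>` names shown above, one option is to match by name instead. A sketch, assuming that naming scheme holds for your export; `build_feed` and its `prefix` default are hypothetical, and the earlier `_obs`, `_obs_1` names produced by `dynamo_export` would not match it:

```python
from typing import Dict, Iterable

import numpy as np


def build_feed(
    input_names: Iterable[str], obs: Dict[str, np.ndarray], prefix: str = "observation_"
) -> Dict[str, np.ndarray]:
    """Hypothetical helper: map ONNX input names like 'observation_a' to obs['a'] by name."""
    feed = {}
    for name in input_names:
        if not name.startswith(prefix):
            raise KeyError(f"Unexpected ONNX input name: {name!r}")
        feed[name] = obs[name[len(prefix):]]
    return feed


# In practice: input_names = [inp.name for inp in ort_session.get_inputs()]
input_names = ["observation_a", "observation_b", "observation_c"]
# Deliberately shuffled key order: matching by name still works
obs = {
    "c": np.zeros((8, 9), dtype=np.float32),
    "a": np.zeros((8, 3), dtype=np.float32),
    "b": np.zeros((8, 6), dtype=np.float32),
}
feed = build_feed(input_names, obs)
for name, value in feed.items():
    print(f"{name}: {value.shape}")
```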

Package versions:

  • stable_baselines3: 2.5.0
  • torch: 2.6.0
  • onnx: 1.17.0
  • onnxruntime: 1.20.1
  • onnxscript: 0.2.1

@araffin araffin linked a pull request Mar 10, 2025 that will close this issue
@darkopetrovic

> Could you do a PR that adds a link in the documentation (section "exporting models") that redirects here?

PR created #2098

Forgive me if I did not perform all the required checks. For a single sentence, I thought it was not necessary :)


5 participants