Skip to content

Added ProMPs and Kinematics functionalities #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ This library (or variations thereof) has been successfully utilized in the follo

[G. Clark, J. Campbell, and H. Ben Amor. Learning Predictive Models for Ergonomic Control of Prosthetic Devices](https://arxiv.org/pdf/2011.07005.pdf)

[V. Prasad, R. Stock-Homburg, and J. Peters. Learning Human-like Hand Reaching for Human-Robot Handshaking](https://arxiv.org/abs/2103.00616)

## Acknowledgements

This work was supported in part by the National Science Foundation under grant No. IIS-1749783 and the Honda Research Institute.
Expand Down
442 changes: 442 additions & 0 deletions docs/notebooks/5_promp_example.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions intprim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from intprim.bayesian_interaction_primitives import *
from intprim.probabilistic_movement_primitives import *
import intprim.basis
import intprim.constants
import intprim.examples
Expand Down
8 changes: 6 additions & 2 deletions intprim/basis/basis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, degree, observed_dof_names):
# Gets the block diagonal basis matrix for the given phase value(s).
# Used to transform vectors from the basis space to the measurement space.
#
# @param x Scalar of vector of dimension T containing the phase values to use in the creation of the block diagonal matrix.
# @param x Scalar or vector of dimension T containing the phase values to use in the creation of the block diagonal matrix.
# @param out_array Matrix of dimension greater to or equal than (degree * num_observed_dof * T) x num_observed_dof in which the results are stored. If none, an internal matrix is used.
# @param start_row A row offset to apply to results in the block diagonal matrix.
# @param start_col A column offset to apply to results in the block diagonal matrix.
Expand All @@ -46,6 +46,8 @@ def get_block_diagonal_basis_matrix(self, x, out_array = None, start_row = 0, st
out_array = self.block_prototype

basis_funcs = self.get_basis_functions(x)
if np.isscalar(x):
basis_funcs = basis_funcs[:, None]
for block_index in range(self._num_blocks):
out_array[start_row + block_index * self._degree : start_row + (block_index + 1) * self._degree, start_col + block_index : start_col + block_index + 1] = basis_funcs

Expand All @@ -67,6 +69,8 @@ def get_block_diagonal_basis_matrix_derivative(self, x, out_array = None, start_
out_array = self.block_prototype

basis_funcs = self.get_basis_function_derivatives(x)
if np.isscalar(x):
basis_funcs = basis_funcs[:, None]
for block_index in range(self._num_blocks):
out_array[start_row + block_index * self._degree : start_row + (block_index + 1) * self._degree, start_col + block_index : start_col + block_index + 1] = basis_funcs

Expand All @@ -88,7 +92,7 @@ def get_weighted_vector_derivative(self, x, weights, out_array = None, start_row
out_array = np.zeros((1, self._degree))

out_row = start_row
basis_func_derivs = self.get_basis_function_derivatives(x[0])
basis_func_derivs = self.get_basis_function_derivatives(x)

# temp_weights = self.inverse_transform(weights)
temp_weights = weights
Expand Down
2 changes: 1 addition & 1 deletion intprim/basis/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# This module defines the GaussianModel class.
#
# @author Joseph Campbell <[email protected]>, Interactive Robotics Lab, Arizona State University
import basis_model
from intprim.basis import basis_model
import intprim.constants
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion intprim/basis/mixture_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# This module defines the MixtureModel class.
#
# @author Joseph Campbell <[email protected]>, Interactive Robotics Lab, Arizona State University
import basis_model
from intprim.basis import basis_model
import intprim.constants
import numpy as np
import scipy.linalg
Expand Down
2 changes: 1 addition & 1 deletion intprim/basis/polynomial_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# This module defines the PolynomialModel class.
#
# @author Joseph Campbell <[email protected]>, Interactive Robotics Lab, Arizona State University
import basis_model
from intprim.basis import basis_model
import intprim.constants
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion intprim/basis/sigmoidal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# This module defines the SigmoidalModel class.
#
# @author Joseph Campbell <[email protected]>, Interactive Robotics Lab, Arizona State University
import basis_model
from intprim.basis import basis_model
import intprim.constants
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion intprim/bayesian_interaction_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_approximate_trajectory(self, trajectory, num_samples = intprim.constants
# @return approximate_trajectory Matrix of dimension D x num_samples containing the approximate trajectory.
#
def get_approximate_trajectory_derivative(self, trajectory, num_samples = intprim.constants.DEFAULT_NUM_SAMPLES):
return get_approximate_trajectory(trajectory, num_samples, deriv = True)
return self.get_approximate_trajectory(trajectory, num_samples, deriv = True)

##
# Gets the probability distribution of the trained demonstrations.
Expand Down
7 changes: 4 additions & 3 deletions intprim/examples/minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@

# Compute the phase mean and phase velocities from the demonstrations.
phase_velocity_mean, phase_velocity_var = intprim.examples.get_phase_stats(training_trajectories)

mean_w, cov_w = primitive.get_basis_weight_parameters()
# Define a filter to use. Here we use an ensemble Kalman filter
filter = intprim.filter.spatiotemporal.EnsembleKalmanFilter(
filter = intprim.filter.spatiotemporal.ExtendedKalmanFilter(
basis_model = basis_model,
initial_phase_mean = [0.0, phase_velocity_mean],
initial_phase_var = [1e-4, phase_velocity_var],
proc_var = 1e-8,
initial_ensemble = primitive.basis_weights)
mean_basis_weights = mean_w,
cov_basis_weights = cov_w)



Expand Down
231 changes: 231 additions & 0 deletions intprim/examples/promp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
### In order to run this, please first download the sample data at https://github.com/sebasutp/promp/blob/master/examples/strike_mov.npz
### This code is adapted from the code at https://github.com/sebasutp/promp/

import intprim
from intprim.probabilistic_movement_primitives import *
import numpy as np
import matplotlib.pyplot as plt
from intprim.util.kinematics import BaseKinematicsClass
import sys


# Download the sample data from https://github.com/sebasutp/promp/blob/master/examples/strike_mov.npz and provide it as an argument while running this file
if len(sys.argv)!=2:
print('Download the sample data from https://github.com/sebasutp/promp/blob/master/examples/strike_mov.npz and provide it as an argument while running this file')
print ('Usage: python promp_example.py /path/to/strike_mov.npz')

with open(sys.argv[1],'rb') as f:
data = np.load(f, allow_pickle=True, encoding='bytes')
time = data['time']
Q = data['Q']
num_joints = Q[0].shape[1]

# Create a ProMP with Gaussian basis functions.
basis_model = intprim.basis.GaussianModel(8, 0.1, ['joint'+str(i) for i in range(num_joints)])
promp = ProMP(basis_model)

# Add Demonstrations to the ProMP, which in turn calculates the list of weights for each demonstration.
for i in range(len(Q)):
promp.add_demonstration(Q[i].T)

# Plot samples from the learnt ProMP

n_samples = 20 # Number of trajectoies to sample
plot_dof = 3 # Degree of freedom to plot
domain = np.linspace(0,1,100)

for i in range(n_samples):
samples, _ = promp.generate_probable_trajectory(domain)
plt.plot(domain, samples[plot_dof,:], 'g--', alpha=0.3)

mean_margs = np.zeros(samples[plot_dof,:].shape)
upper_bound = np.zeros(samples[plot_dof,:].shape)
lower_bound = np.zeros(samples[plot_dof,:].shape)
for i in range(len(domain)):
mu_marg_q, Sigma_marg_q = promp.get_marginal(domain[i])
std_q = Sigma_marg_q[plot_dof][plot_dof] ** 0.5

mean_margs[i] = mu_marg_q[plot_dof]
upper_bound[i] = mu_marg_q[plot_dof] + std_q
lower_bound[i] = mu_marg_q[plot_dof] - std_q

plt.fill_between(domain, upper_bound, lower_bound, color = 'g', alpha=0.2)
plt.plot(domain, mean_margs, 'g-')
plt.title('Samples for DoF {}'.format(plot_dof))
plt.show()


q_cond_init = [1.54, 0.44, 0.15, 1.65, 0.01, -0.09, -1.23]
t = 0

# Condition the ProMP and obtain the posterior distribution.
mu_w_cond, Sigma_w_cond = promp.get_conditioned_weights(t, q_cond_init)

# Plot samples of the conditioned ProMP drawn from the predicted posterior.
plt.plot(0,q_cond_init[plot_dof],'bo')
for i in range(n_samples):
samples, _ = promp.generate_probable_trajectory(domain, mu_w_cond, Sigma_w_cond)
plt.plot(domain, samples[plot_dof,:], 'b--', alpha=0.3)

for i in range(len(domain)):
mu_marg_q, Sigma_marg_q = promp.get_marginal(domain[i], mu_w_cond, Sigma_w_cond)
std_q = Sigma_marg_q[plot_dof][plot_dof] ** 0.5

mean_margs[i] = mu_marg_q[plot_dof]
upper_bound[i] = mu_marg_q[plot_dof] + std_q
lower_bound[i] = mu_marg_q[plot_dof] - std_q

plt.fill_between(domain, upper_bound, lower_bound, color = 'b', alpha=0.2)
plt.plot(domain, mean_margs, 'b-')
plt.title('Joint Space Conditioned Samples for DoF {}'.format(plot_dof))
plt.show()


q_cond = np.random.choice(Q)
q_domain = np.linspace(0, 1, len(q_cond))

# Condition the ProMP and obtain the posterior distribution.
mu_w_cond_rec, Sigma_w_cond_rec = promp.get_conditioned_weights(q_domain[0], q_cond[0])
num_observed = 25
for i in range(1, num_observed):
mu_w_cond_rec, Sigma_w_cond_rec = promp.get_conditioned_weights(q_domain[i], q_cond[i], mean_w=mu_w_cond_rec, var_w=Sigma_w_cond_rec)
# Plot samples of the conditioned ProMP drawn from the predicted posterior.
plt.plot(q_domain[:num_observed],q_cond[:num_observed,plot_dof],'bo', alpha=0.3)
plt.plot(q_domain[num_observed:],q_cond[num_observed:,plot_dof],'ro', alpha=0.3)
for i in range(n_samples):
samples, _ = promp.generate_probable_trajectory(domain, mu_w_cond_rec, Sigma_w_cond_rec)
plt.plot(domain, samples[plot_dof,:], 'b--', alpha=0.3)

mean_margs = np.zeros(samples[plot_dof,:].shape)
upper_bound = np.zeros(samples[plot_dof,:].shape)
lower_bound = np.zeros(samples[plot_dof,:].shape)
std_Q = []
for i in range(len(domain)):
mu_marg_q, Sigma_marg_q = promp.get_marginal(domain[i], mu_w_cond_rec, Sigma_w_cond_rec)
std_q = Sigma_marg_q[plot_dof][plot_dof]
std_Q.append(std_q)
mean_margs[i] = mu_marg_q[plot_dof]
upper_bound[i] = mu_marg_q[plot_dof] + std_q
lower_bound[i] = mu_marg_q[plot_dof] - std_q
plt.fill_between(domain, upper_bound, lower_bound, color = 'b', alpha=0.2)
plt.plot(domain, mean_margs, 'b-')
plt.title('Joint Space Recursively Conditioned Samples for DoF {}'.format(plot_dof))
plt.show()

class BarrettKinematics(BaseKinematicsClass):
''' Forward kinematics object for the Barrett Wam
This class implements the forwark kinematics functionality for the
Barrett Wam arm used in the table tennis setup at the MPI. The end
effector position can be changes with the endeff parameter received
in the constructor.

The code is taken from https://github.com/sebasutp/promp/blob/master/robpy/kinematics/forward.py
'''

def __init__(self, endeff = [0.0, 0.0, 0.3, 0.0, 0.0, 0.0]):
self.ZSFE = 0.346
self.ZHR = 0.505
self.YEB = 0.045
self.ZEB = 0.045
self.YWR = -0.045
self.ZWR = 0.045
self.ZWFE = 0.255
self.endeff = endeff

def _link_matrices(self,q):
cq = np.cos(q)
sq = np.sin(q)

sa=np.sin(self.endeff[3])
ca=np.cos(self.endeff[3])

sb=np.sin(self.endeff[4])
cb=np.cos(self.endeff[4])

sg=np.sin(self.endeff[5])
cg=np.cos(self.endeff[5])

hi00 = np.array([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,1]])
hi01 = np.array([[cq[0],-sq[0],0,0],[sq[0],cq[0],0,0],[0,0,1,self.ZSFE],[0,0,0,1]])
hi12 = np.array([[0,0,-1,0],[sq[1],cq[1],0,0],[cq[1],-sq[1],0,0],[0,0,0,1]])
hi23 = np.array([[0,0,1,self.ZHR],[sq[2],cq[2],0,0],[-cq[2],sq[2],0,0],[0,0,0,1]])
hi34 = np.array([[0,0,-1,0],[sq[3],cq[3],0,self.YEB],[cq[3],-sq[3],0,self.ZEB],[0,0,0,1]])
hi45 = np.array([[0,0,1,self.ZWR],[sq[4],cq[4],0,self.YWR],[-cq[4],sq[4],0,0],[0,0,0,1]])
hi56 = np.array([[0,0,-1,0],[sq[5],cq[5],0,0],[cq[5],-sq[5],0,self.ZWFE],[0,0,0,1]])
hi67 = np.array([[0,0,1,0],[sq[6],cq[6],0,0],[-cq[6],sq[6],0,0],[0,0,0,1]])
hi78 = np.array([[cb*cg, -(cb*sg), sb, self.endeff[0]], \
[cg*sa*sb + ca*sg, ca*cg - sa*sb*sg, -(cb*sa), self.endeff[1]], \
[-(ca*cg*sb) + sa*sg, cg*sa + ca*sb*sg, ca*cb, self.endeff[2]], \
[0,0,0,1]])
return [hi00,hi01,hi12,hi23,hi34,hi45,hi56,hi67,hi78]

mu_cartesian = np.array([-0.62, -0.44, -0.34])
Sigma_cartesian = 0.02**2*np.eye(3)

# Get the prior distribution of the joint space at the desired time.
t = 1
prior_mu_q, prior_Sigma_q = promp.get_marginal(t)

# Compute the posterior distribution of the joint space using inverse kinematics.
fwd_kin = BarrettKinematics()
mu_q, Sigma_q = fwd_kin.inv_kin(prior_mu_q, prior_Sigma_q, mu_cartesian, Sigma_cartesian)

# Finally, condition the ProMP using the posterior joint space distribution.
mu_w_task, Sigma_w_task = promp.get_conditioned_weights(t, mu_q, Sigma_q)

# Plot samples of the conditioned ProMP.
std_q = Sigma_q[plot_dof][plot_dof] ** 0.5
plt.errorbar(t, mu_q[plot_dof], yerr=std_q, color='r', fmt='--o', capsize=4)
for i in range(n_samples):
samples, _ = promp.generate_probable_trajectory(domain, mu_w_task, Sigma_w_task)
plt.plot(domain, samples[plot_dof,:], 'r--', alpha=0.3)

for i in range(len(domain)):
mu_marg_q, Sigma_marg_q = promp.get_marginal(domain[i], mu_w_task, Sigma_w_task)
std_q = Sigma_marg_q[plot_dof][plot_dof] ** 0.5

mean_margs[i] = mu_marg_q[plot_dof]
upper_bound[i] = mu_marg_q[plot_dof] + std_q
lower_bound[i] = mu_marg_q[plot_dof] - std_q

plt.fill_between(domain, upper_bound, lower_bound, color = 'r', alpha=0.2)
plt.plot(domain, mean_margs, 'r-')
plt.title('Task Space Conditioned Samples for DoF {}'.format(plot_dof))
plt.show()


t = 0

# First condition the ProMP in the joint space.
mu_w_cond, Sigma_w_cond = promp.get_conditioned_weights(t, q_cond_init)
plt.plot(0,q_cond_init[plot_dof],'mo')

# Get the prior distribution of the joint space from the conditioned ProMP.
t = 1
prior_mu_q, prior_Sigma_q = promp.get_marginal(t, mu_w_cond, Sigma_w_cond)

# Compute the posterior distribution of the joint space using inverse kinematics.
mu_q, Sigma_q = fwd_kin.inv_kin(prior_mu_q, prior_Sigma_q, mu_cartesian, Sigma_cartesian)

# Finally, condition the ProMP using the posterior joint space distribution.
mu_w_task, Sigma_w_task = promp.get_conditioned_weights(t, mu_q, Sigma_q, mu_w_cond, Sigma_w_cond)

# Plot samples of the conditioned ProMP.
std_q = Sigma_q[plot_dof][plot_dof] ** 0.5
plt.errorbar(t, mu_q[plot_dof], yerr=std_q, color='m', fmt='--o', capsize=4)
for i in range(n_samples):
samples, _ = promp.generate_probable_trajectory(domain, mu_w_task, Sigma_w_task)
plt.plot(domain, samples[plot_dof,:], 'm--', alpha=0.3)

for i in range(len(domain)):
mu_marg_q, Sigma_marg_q = promp.get_marginal(domain[i], mu_w_task, Sigma_w_task)
std_q = Sigma_marg_q[plot_dof][plot_dof] ** 0.5

mean_margs[i] = mu_marg_q[plot_dof]
upper_bound[i] = mu_marg_q[plot_dof] + std_q
lower_bound[i] = mu_marg_q[plot_dof] - std_q

plt.fill_between(domain, upper_bound, lower_bound, color = 'm', alpha=0.2)
plt.plot(domain, mean_margs, 'm-')
plt.title('Joint and Task Space Conditioned Samples for DoF {}'.format(plot_dof))
plt.show()
3 changes: 3 additions & 0 deletions intprim/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from intprim.filter.spatiotemporal import *
from intprim.filter.align import *
from intprim.filter.linear_system import *
from intprim.filter.kf import *
Loading