Skip to content

Commit d22f117

Browse files
author
Kwesi Rutledge
committed
Changed how we concretize matrix and vectors of monomials
1 parent 236ae42 commit d22f117

File tree

5 files changed

+193
-53
lines changed

5 files changed

+193
-53
lines changed

symbolic/monomial.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,11 @@ func (m Monomial) DerivativeWrt(vIn Variable) Expression {
543543
// monomial.
544544
var monomialOut Monomial
545545

546-
if m.Exponents[foundIndex] == 1 {
546+
switch {
547+
case (m.Exponents[foundIndex] == 1) && (len(m.VariableFactors) == 1):
548+
// Return the exponent
549+
return K(m.Coefficient)
550+
case m.Exponents[foundIndex] == 1:
547551
// If the degree of vIn is 1, then remove it from the monomial
548552
monomialOut.Coefficient = m.Coefficient
549553
for ii, variable := range m.VariableFactors {
@@ -552,7 +556,7 @@ func (m Monomial) DerivativeWrt(vIn Variable) Expression {
552556
monomialOut.Exponents = append(monomialOut.Exponents, m.Exponents[ii])
553557
}
554558
}
555-
} else {
559+
default:
556560
monomialOut = m
557561
monomialOut.Coefficient = m.Coefficient * float64(m.Exponents[foundIndex])
558562
monomialOut.Exponents[foundIndex] -= 1

symbolic/monomial_matrix.go

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -505,30 +505,19 @@ func (mm MonomialMatrix) DerivativeWrt(vIn Variable) Expression {
505505
}
506506

507507
// Compute the Derivative of each monomial
508-
var dmm MonomialMatrix
508+
var dmm [][]ScalarExpression
509509
for _, row := range mm {
510-
var dmmRow []Monomial
510+
var dmmRow []ScalarExpression
511511
for _, monomial := range row {
512512
dMonomial := monomial.DerivativeWrt(vIn)
513-
var dMonomialAsMonomial Monomial
514-
switch dMonomial.(type) {
515-
case Monomial:
516-
dMonomialAsMonomial = dMonomial.(Monomial)
517-
case K:
518-
dMonomialAsMonomial = dMonomial.(K).ToMonomial()
519-
default:
520-
panic(
521-
fmt.Errorf("unexpected type of derivative: %T (%v)", dMonomial, dMonomial),
522-
)
523-
}
524-
// Add the converted dMonomial to dmmRow
525-
dmmRow = append(dmmRow, dMonomialAsMonomial)
513+
dMonomialAsSE, _ := ToScalarExpression(dMonomial)
514+
dmmRow = append(dmmRow, dMonomialAsSE) // Add the converted dMonomial to dmmRow
526515
}
527516
dmm = append(dmm, dmmRow)
528517
}
529518

530519
// Return the derivative
531-
return dmm
520+
return ConcretizeMatrixExpression(dmm)
532521
}
533522

534523
/*

testing/symbolic/monomial_matrix_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,175 @@ func TestMonomialMatrix_Eq2(t *testing.T) {
16341634
}
16351635
}
16361636

1637+
/*
1638+
TestMonomialMatrix_DerivativeWrt1
1639+
Description:
1640+
1641+
This test checks that the DerivativeWrt() method properly panics when it is called
1642+
with a monomial matrix that is not well formed.
1643+
*/
1644+
func TestMonomialMatrix_DerivativeWrt1(t *testing.T) {
1645+
// Constants
1646+
var mm symbolic.MonomialMatrix
1647+
1648+
// Test
1649+
defer func() {
1650+
r := recover()
1651+
if r == nil {
1652+
t.Errorf(
1653+
"expected DerivativeWrt() to panic; it did not",
1654+
)
1655+
}
1656+
1657+
rAsE, ok := r.(error)
1658+
if !ok {
1659+
t.Errorf(
1660+
"expected DerivativeWrt() to panic with an error; it panicked with %v",
1661+
r,
1662+
)
1663+
}
1664+
1665+
expectedError := mm.Check()
1666+
if !strings.Contains(rAsE.Error(), expectedError.Error()) {
1667+
t.Errorf(
1668+
"expected DerivativeWrt() to panic with error \"%v\"; it panicked with \"%v\"",
1669+
expectedError,
1670+
rAsE,
1671+
)
1672+
}
1673+
}()
1674+
mm.DerivativeWrt(symbolic.NewVariable())
1675+
}
1676+
1677+
/*
1678+
TestMonomialMatrix_DerivativeWrt2
1679+
Description:
1680+
1681+
This test checks that the DerivativeWrt() method properly panics when it is called
1682+
with a monomial matrix that is well-formed and a variable that is not well-defined.
1683+
*/
1684+
func TestMonomialMatrix_DerivativeWrt2(t *testing.T) {
1685+
// Constants
1686+
v1 := symbolic.NewVariable()
1687+
badV := symbolic.Variable{}
1688+
m1 := v1.ToMonomial()
1689+
var mm symbolic.MonomialMatrix = [][]symbolic.Monomial{
1690+
{m1, m1},
1691+
{m1, m1},
1692+
}
1693+
1694+
// Test
1695+
defer func() {
1696+
r := recover()
1697+
if r == nil {
1698+
t.Errorf(
1699+
"expected DerivativeWrt() to panic; it did not",
1700+
)
1701+
}
1702+
1703+
rAsE, ok := r.(error)
1704+
if !ok {
1705+
t.Errorf(
1706+
"expected DerivativeWrt() to panic with an error; it panicked with %v",
1707+
r,
1708+
)
1709+
}
1710+
1711+
expectedError := badV.Check()
1712+
if !strings.Contains(rAsE.Error(), expectedError.Error()) {
1713+
t.Errorf(
1714+
"expected DerivativeWrt() to panic with error \"%v\"; it panicked with \"%v\"",
1715+
expectedError,
1716+
rAsE,
1717+
)
1718+
}
1719+
}()
1720+
1721+
mm.DerivativeWrt(badV)
1722+
}
1723+
1724+
/*
1725+
TestMonomialMatrix_DerivativeWrt3
1726+
Description:
1727+
1728+
This test checks that the DerivativeWrt() method properly returns a matrix of
1729+
monomials that are the derivatives of the original monomials with respect to
1730+
the given variable.
1731+
*/
1732+
func TestMonomialMatrix_DerivativeWrt3(t *testing.T) {
1733+
// Constants
1734+
v1 := symbolic.NewVariable()
1735+
v2 := symbolic.NewVariable()
1736+
m1 := v1.ToMonomial()
1737+
m2 := v2.ToMonomial()
1738+
var mm symbolic.MonomialMatrix = [][]symbolic.Monomial{
1739+
{m1, m2},
1740+
{m1, m2},
1741+
}
1742+
1743+
// Test
1744+
derivatives := mm.DerivativeWrt(v1)
1745+
1746+
// Check that the dimensions of the derivatives are (2,2)
1747+
if dims := derivatives.Dims(); dims[0] != 2 || dims[1] != 2 {
1748+
t.Errorf(
1749+
"expected DerivativeWrt() to return a MonomialMatrix with dimensions (2,2); received %v",
1750+
dims,
1751+
)
1752+
}
1753+
1754+
dAsMM, ok := derivatives.(symbolic.KMatrix)
1755+
if !ok {
1756+
t.Errorf(
1757+
"expected DerivativeWrt() to return a MonomialMatrix; received %T",
1758+
derivatives,
1759+
)
1760+
1761+
}
1762+
1763+
// Check that the derivatives are correct
1764+
for ii, row := range dAsMM {
1765+
for jj, derivative := range row {
1766+
// Check that the derivative is the correct monomial
1767+
if ii == 0 {
1768+
if jj == 0 {
1769+
if float64(derivative) != 1.0 {
1770+
t.Errorf(
1771+
"expected DerivativeWrt() to return a MonomialMatrix with derivative %v at (0,0); received %v",
1772+
1.0,
1773+
derivative,
1774+
)
1775+
}
1776+
} else {
1777+
if float64(derivative) != 0.0 {
1778+
t.Errorf(
1779+
"expected DerivativeWrt() to return a MonomialMatrix with derivative 0.0 at (0,1); received %v",
1780+
derivative,
1781+
)
1782+
}
1783+
}
1784+
} else {
1785+
if jj == 0 {
1786+
if float64(derivative) != 1.0 {
1787+
t.Errorf(
1788+
"expected DerivativeWrt() to return a MonomialMatrix with derivative %v at (1,1); received %v",
1789+
1.0,
1790+
derivative,
1791+
)
1792+
}
1793+
} else {
1794+
if float64(derivative) != 0.0 {
1795+
t.Errorf(
1796+
"expected DerivativeWrt() to return a MonomialMatrix with derivative 0.0 at (1,0); received %v",
1797+
derivative,
1798+
)
1799+
}
1800+
}
1801+
}
1802+
}
1803+
}
1804+
}
1805+
16371806
/*
16381807
TestMonomialMatrix_Transpose2
16391808
Description:

testing/symbolic/monomial_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,8 @@ func TestMonomial_DerivativeWrt2(t *testing.T) {
14601460
// Compute DerivativeWrt
14611461
derivative := m1.DerivativeWrt(v1)
14621462

1463-
// Verify that the derivative is a monomial
1464-
derivativeAsM, tf := derivative.(symbolic.Monomial)
1463+
// Verify that the derivative is a constant (K)
1464+
derivativeAsK, tf := derivative.(symbolic.K)
14651465
if !tf {
14661466
t.Errorf(
14671467
"expected derivative to be a monomial; received %T",
@@ -1470,10 +1470,10 @@ func TestMonomial_DerivativeWrt2(t *testing.T) {
14701470
}
14711471

14721472
// Verify that the derivative is a constant
1473-
if derivativeAsM.Coefficient != 3.14 {
1473+
if float64(derivativeAsK) != 3.14 {
14741474
t.Errorf(
14751475
"expected derivative to be a constant; received %v",
1476-
derivativeAsM.Coefficient,
1476+
derivativeAsK,
14771477
)
14781478
}
14791479
}

testing/symbolic/monomial_vector_test.go

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,7 +1626,7 @@ func TestMonomialVector_Derivative3(t *testing.T) {
16261626
derivative := mv.DerivativeWrt(v1)
16271627

16281628
// Verify that the derivative is a K vector
1629-
if _, tf := derivative.(symbolic.MonomialVector); !tf {
1629+
if _, tf := derivative.(symbolic.KVector); !tf {
16301630
t.Errorf(
16311631
"expected derivative to be a MonomialVector; received %T",
16321632
derivative,
@@ -1635,20 +1635,12 @@ func TestMonomialVector_Derivative3(t *testing.T) {
16351635

16361636
// Verify that each element of the derivative is just the coefficient
16371637
// from the original monomial vector mv
1638-
for ii, monomial := range derivative.(symbolic.MonomialVector) {
1639-
// Check that the monomial is a constant
1640-
if !monomial.IsConstant() {
1641-
t.Errorf(
1642-
"expected monomial to be a constant; received %v",
1643-
monomial,
1644-
)
1645-
}
1646-
1647-
if monomial.Coefficient != mv[ii].Coefficient {
1638+
for ii, d_ii := range derivative.(symbolic.KVector) {
1639+
if float64(d_ii) != mv[ii].Coefficient {
16481640
t.Errorf(
16491641
"expected constant to be %v; received %v",
16501642
mv[ii].Coefficient,
1651-
monomial.Coefficient,
1643+
float64(d_ii),
16521644
)
16531645
}
16541646
}
@@ -1683,7 +1675,7 @@ func TestMonomialVector_Derivative4(t *testing.T) {
16831675
derivative := mv.DerivativeWrt(v1)
16841676

16851677
// Verify that the derivative is a K vector
1686-
d_v1, tf := derivative.(symbolic.MonomialVector)
1678+
d_v1, tf := derivative.(symbolic.KVector)
16871679
if !tf {
16881680
t.Errorf(
16891681
"expected derivative to be a MonomialVector; received %T",
@@ -1692,32 +1684,18 @@ func TestMonomialVector_Derivative4(t *testing.T) {
16921684
}
16931685

16941686
// Verify that the first element of the derivative is a constant and nonzero
1695-
if !d_v1[0].IsConstant() {
1696-
t.Errorf(
1697-
"expected derivative[0] to be a constant; received %v",
1698-
d_v1[0],
1699-
)
1700-
}
1701-
1702-
if d_v1[0].Coefficient != 3.14 {
1687+
if float64(d_v1[0]) != 3.14 {
17031688
t.Errorf(
17041689
"expected derivative[0].Coefficient to be 3.14; received %v",
1705-
d_v1[0].Coefficient,
1690+
d_v1[0],
17061691
)
17071692
}
17081693

17091694
// Verify that the second element of the derivative is a constant and zero
1710-
if !d_v1[1].IsConstant() {
1711-
t.Errorf(
1712-
"expected derivative[1] to be a constant; received %v",
1713-
d_v1[1],
1714-
)
1715-
}
1716-
1717-
if d_v1[1].Coefficient != 0 {
1695+
if float64(d_v1[1]) != 0 {
17181696
t.Errorf(
17191697
"expected derivative[1].Coefficient to be 0; received %v",
1720-
d_v1[1].Coefficient,
1698+
d_v1[1],
17211699
)
17221700
}
17231701

0 commit comments

Comments
 (0)