Skip to content

Commit 87250bd

Browse files
author
Kwesi Rutledge
committed
Updated ConstantMatrix.Multiply to accept more types + added tests for more Polynomial functions
1 parent f41ff0a commit 87250bd

File tree

3 files changed

+554
-10
lines changed

3 files changed

+554
-10
lines changed

symbolic/constant_matrix.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,32 @@ func (km KMatrix) Multiply(e interface{}) Expression {
274274

275275
case K:
276276
return km.Multiply(float64(right)) // Reuse float64 case
277+
case Polynomial:
278+
// Choose the correct output type based on the size of km
279+
nR, nC := km.Dims()[0], km.Dims()[1]
280+
switch {
281+
case (nR == 1) && (nC == 1):
282+
// If the output is a scalar, return a scalar
283+
return km[0][0].Multiply(right)
284+
case nC == 1:
285+
// If the output is a vector, return a vector
286+
var outputVec PolynomialVector = make([]Polynomial, nR)
287+
for rIndex := 0; rIndex < nR; rIndex++ {
288+
outputVec[rIndex] = km[rIndex][0].Multiply(right.Copy()).(Polynomial)
289+
}
290+
return outputVec
291+
default:
292+
// If the output is a matrix, return a matrix
293+
var outputMat PolynomialMatrix = make([][]Polynomial, nR)
294+
for rIndex := 0; rIndex < nR; rIndex++ {
295+
outputMat[rIndex] = make([]Polynomial, nC)
296+
for cIndex := 0; cIndex < nC; cIndex++ {
297+
outputMat[rIndex][cIndex] = km[rIndex][cIndex].Multiply(right.Copy()).(Polynomial)
298+
}
299+
}
300+
return outputMat
301+
}
302+
277303
case *mat.VecDense:
278304
// Use gonum's built-in multiplication function
279305
var product mat.VecDense

symbolic/polynomial.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,6 @@ func (p Polynomial) Minus(e interface{}) Expression {
239239
return Minus(p, eAsE)
240240
}
241241

242-
// Constants
243-
switch right := e.(type) {
244-
case float64:
245-
return p.Copy().Minus(K(right))
246-
}
247-
248242
// If the function has reached this point, then
249243
// the input is not recognized
250244
panic(
@@ -429,7 +423,13 @@ Description:
429423
The transpose operator when applied to a scalar is just the same scalar object.
430424
*/
431425
func (p Polynomial) Transpose() Expression {
432-
return p
426+
// Input Processing
427+
err := p.Check()
428+
if err != nil {
429+
panic(err)
430+
}
431+
432+
return p.Copy()
433433
}
434434

435435
/*
@@ -624,8 +624,19 @@ func (p Polynomial) DerivativeWrt(vIn Variable) Expression {
624624
}
625625

626626
// Append
627-
components := monomial.DerivativeWrt(vIn)
628-
derivative.Monomials = append(derivative.Monomials, components.(Monomial))
627+
dMonomial := monomial.DerivativeWrt(vIn)
628+
switch component := dMonomial.(type) {
629+
case Monomial:
630+
derivative.Monomials = append(derivative.Monomials, component)
631+
case K:
632+
// Skip zero monomials
633+
if float64(component) == 0.0 {
634+
continue
635+
}
636+
derivative.Monomials = append(derivative.Monomials, component.ToMonomial())
637+
default:
638+
panic(fmt.Errorf("Unexpected type in Polynomial.Derivative: %T", component))
639+
}
629640
}
630641

631642
// If the derivative is empty, then return 0.0
@@ -780,7 +791,7 @@ func (p Polynomial) Substitute(vIn Variable, eIn ScalarExpression) Expression {
780791
var out Expression = K(0.0)
781792
for _, monomial := range p.Monomials {
782793
newMonomial := monomial.Substitute(vIn, eIn)
783-
out = out.Plus(newMonomial)
794+
out = out.Plus(newMonomial).(Polynomial).Simplify()
784795
}
785796

786797
return out

0 commit comments

Comments
 (0)