Skip to content

Commit b90c05f

Browse files
Merge pull request #4 from MatProGo-dev/kr-feature-multiply1
Redid Expressions So That Multiply(), Plus(), Eq(), LessEq(), Greater…
2 parents fd1cd3a + 0284582 commit b90c05f

39 files changed

+1191
-894
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,9 @@ code. Hopefully, this is avoided using this format.
157157
* [X] AtVec()
158158
* [ ] Write changes to all AtVec() methods to output both elements AND errors (so we can detect out of length calls)
159159
* [ ] Determine whether or not to keep the Solution and Solver() interfaces in this module. It seems like they can be solver-specific.
160-
* [ ] Introduce MatrixVar object
160+
* [ ] Introduce MatrixVar object
161+
* [ ] Add The Following to the Expression Interface
162+
* [ ] Comparison
163+
* [ ] LessEq
164+
* [ ] GreaterEq
165+
* [ ] Eq

optim/constant.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,47 +52,52 @@ func (c K) Constant() float64 {
5252

5353
// Plus adds the current expression to another and returns the resulting
5454
// expression
55-
func (c K) Plus(e interface{}, errors ...error) (ScalarExpression, error) {
55+
func (c K) Plus(rightIn interface{}, errors ...error) (Expression, error) {
5656
// Input Processing
5757
err := CheckErrors(errors)
5858
if err != nil {
5959
return c, err
6060
}
6161

62+
if IsExpression(rightIn) {
63+
rightAsE, _ := ToExpression(rightIn)
64+
err = CheckDimensionsInAddition(c, rightAsE)
65+
if err != nil {
66+
return c, err
67+
}
68+
}
69+
6270
// Switching based on input type
63-
switch e.(type) {
71+
switch right := rightIn.(type) {
6472
case K:
65-
eAsK, _ := e.(K)
66-
return K(c.Constant() + eAsK.Constant()), nil
73+
return K(c.Constant() + right.Constant()), nil
6774
case Variable:
68-
eAsVar := e.(Variable)
69-
return eAsVar.Plus(c)
75+
return right.Plus(c)
7076
case ScalarLinearExpr:
71-
eAsSLE := e.(ScalarLinearExpr)
72-
return eAsSLE.Plus(c)
77+
return right.Plus(c)
7378
case ScalarQuadraticExpression:
74-
return e.(ScalarQuadraticExpression).Plus(c) // Very compact, but potentially confusing to read?
79+
return right.Plus(c) // Very compact, but potentially confusing to read?
7580
default:
76-
return c, fmt.Errorf("Unexpected type in K.Plus() for constant %v: %T", e, e)
81+
return c, fmt.Errorf("Unexpected type in K.Plus() for constant %v: %T", right, right)
7782
}
7883
}
7984

8085
// LessEq returns a less than or equal to (<=) constraint between the
8186
// current expression and another
82-
func (c K) LessEq(rhsIn interface{}, errors ...error) (ScalarConstraint, error) {
83-
return c.Comparison(rhsIn, SenseLessThanEqual, errors...)
87+
func (c K) LessEq(rightIn interface{}, errors ...error) (Constraint, error) {
88+
return c.Comparison(rightIn, SenseLessThanEqual, errors...)
8489
}
8590

8691
// GreaterEq returns a greater than or equal to (>=) constraint between the
8792
// current expression and another
88-
func (c K) GreaterEq(rhsIn interface{}, errors ...error) (ScalarConstraint, error) {
89-
return c.Comparison(rhsIn, SenseGreaterThanEqual, errors...)
93+
func (c K) GreaterEq(rightIn interface{}, errors ...error) (Constraint, error) {
94+
return c.Comparison(rightIn, SenseGreaterThanEqual, errors...)
9095
}
9196

9297
// Eq returns an equality (==) constraint between the current expression
9398
// and another
94-
func (c K) Eq(rhsIn interface{}, errors ...error) (ScalarConstraint, error) {
95-
return c.Comparison(rhsIn, SenseEqual, errors...)
99+
func (c K) Eq(rightIn interface{}, errors ...error) (Constraint, error) {
100+
return c.Comparison(rightIn, SenseEqual, errors...)
96101
}
97102

98103
/*
@@ -101,7 +106,7 @@ Description:
101106
102107
This method compares the receiver with expression rhs in the sense provided by sense.
103108
*/
104-
func (c K) Comparison(rhsIn interface{}, sense ConstrSense, errors ...error) (ScalarConstraint, error) {
109+
func (c K) Comparison(rhsIn interface{}, sense ConstrSense, errors ...error) (Constraint, error) {
105110
// InputProcessing
106111
err := CheckErrors(errors)
107112
if err != nil {
@@ -237,3 +242,7 @@ func (c K) Dims() []int {
237242
func (c K) Check() error {
238243
return nil
239244
}
245+
246+
func (c K) Transpose() Expression {
247+
return c
248+
}

optim/constraint.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ Description:
88
*/
99

1010
type Constraint interface {
11+
Left() Expression
12+
Right() Expression
1113
}
1214

13-
func IsConstraint(c Constraint) bool {
15+
func IsConstraint(c interface{}) bool {
1416
switch c.(type) {
1517
case ScalarConstraint:
1618
return true

optim/expression.go

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
)
66

77
/*
8-
expression.go
8+
matrix_expression.go
99
Description:
1010
This file holds all of the functions and methods related to the Expression
1111
interface.
@@ -27,25 +27,31 @@ type Expression interface {
2727
// Dims returns a slice describing the true dimensions of a given expression (scalar, vector, or matrix)
2828
Dims() []int
2929

30-
//// Plus adds the current expression to another and returns the resulting
31-
//// expression
32-
// Plus(e Expression, extras interface{}) (Expression, error)
33-
//
34-
//// Mult multiplies the current expression to another and returns the
35-
//// resulting expression
36-
//Multiply(c interface{}) (Expression, error)
37-
//
38-
//// LessEq returns a less than or equal to (<=) constraint between the
39-
//// current expression and another
40-
//LessEq(e Expression) Constraint
41-
//
42-
//// GreaterEq returns a greater than or equal to (>=) constraint between the
43-
//// current expression and another
44-
//GreaterEq(e Expression) Constraint
45-
//
46-
//// Eq returns an equality (==) constraint between the current expression
47-
//// and another
48-
//Eq(e ScalarExpression) *ScalarConstraint
30+
// Plus adds the current expression to another and returns the resulting
31+
// expression
32+
Plus(e interface{}, errors ...error) (Expression, error)
33+
34+
// Multiply multiplies the current expression to another and returns the
35+
// resulting expression
36+
Multiply(c interface{}, errors ...error) (Expression, error)
37+
38+
// Transpose transposes the given expression
39+
Transpose() Expression
40+
41+
// LessEq returns a less than or equal to (<=) constraint between the
42+
// current expression and another
43+
LessEq(rightIn interface{}, errors ...error) (Constraint, error)
44+
45+
// GreaterEq returns a greater than or equal to (>=) constraint between the
46+
// current expression and another
47+
GreaterEq(rightIn interface{}, errors ...error) (Constraint, error)
48+
49+
// Eq returns an equality (==) constraint between the current expression
50+
// and another
51+
Eq(rightIn interface{}, errors ...error) (Constraint, error)
52+
53+
// Comparison
54+
Comparison(rightIn interface{}, sense ConstrSense, errors ...error) (Constraint, error)
4955
}
5056

5157
/*
@@ -82,3 +88,20 @@ func CheckDimensionsInMultiplication(left, right Expression) error {
8288
// If dimensions match, then return nothing.
8389
return nil
8490
}
91+
92+
func CheckDimensionsInAddition(left, right Expression) error {
93+
// Check that the size of columns in left and right agree
94+
dimsAreMatched := (left.Dims()[0] == right.Dims()[0]) && (left.Dims()[1] == right.Dims()[1])
95+
dimsAreMatched = dimsAreMatched || IsScalarExpression(left)
96+
dimsAreMatched = dimsAreMatched || IsScalarExpression(right)
97+
98+
if !dimsAreMatched {
99+
return DimensionError{
100+
Operation: "Plus",
101+
Arg1: left,
102+
Arg2: right,
103+
}
104+
}
105+
// If dimensions match, then return nothing.
106+
return nil
107+
}

optim/model.go

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -132,48 +132,40 @@ func (m *Model) AddBinaryVariableMatrix(rows, cols int) [][]Variable {
132132
}
133133

134134
// AddConstr adds the given constraint to the model.
135-
func (m *Model) AddConstraint(constr Constraint, extras ...interface{}) error {
135+
func (m *Model) AddConstraint(constr Constraint, errors ...error) error {
136136
// Constants
137-
nExtraArguments := len(extras)
138137

139138
// Input Processing
140-
switch {
141-
case nExtraArguments > 1:
142-
// Do nothing, but report an error
143-
return fmt.Errorf(
144-
"The optimizer tried to add a constraint using a bad call to AddConstr! Skipping this constraint: %v , because of extra inputs %v",
145-
constr,
146-
extras,
147-
)
148-
case nExtraArguments == 1:
149-
switch extra0 := extras[0].(type) {
150-
case error:
151-
if extra0 != nil {
152-
return fmt.Errorf(
153-
"There was an error computing constraint %v: %v",
154-
constr, extra0,
155-
)
156-
}
157-
case nil:
158-
// Do nothing
159-
default:
160-
return fmt.Errorf(
161-
"There was an unexpected type input to AddConstraint(): %T (%v)",
162-
extra0, extra0,
163-
)
164-
}
139+
err := CheckErrors(errors)
140+
if err != nil {
141+
return err
165142
}
166-
// If no extras are given, then move on to last part.
167143

168144
// Algorithm
169145
m.Constraints = append(m.Constraints, constr)
170146
return nil
171147
}
172148

173-
// SetObjective sets the objective of the model given an expression and
174-
// objective sense.
175-
func (m *Model) SetObjective(e ScalarExpression, sense ObjSense) {
176-
m.Obj = NewObjective(e, sense)
149+
/*
150+
SetObjective
151+
Description:
152+
sets the objective of the model given an expression and
153+
objective sense.
154+
Notes:
155+
To make this function easier to parse, we will assume an expression
156+
is given, even though objectives are normally scalars.
157+
*/
158+
159+
func (m *Model) SetObjective(e Expression, sense ObjSense) error {
160+
// Input Processing
161+
se, err := ToScalarExpression(e)
162+
if err != nil {
163+
return fmt.Errorf("trouble parsing input expression: %v", err)
164+
}
165+
166+
// Return
167+
m.Obj = NewObjective(se, sense)
168+
return nil
177169
}
178170

179171
//// Optimize optimizes the model using the given solver type and returns the

optim/operators.go

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func LessEq(lhs, rhs interface{}) (Constraint, error) {
2727
}
2828

2929
// GreaterEq returns a constraint representing lhs >= rhs
30-
func GreaterEq(lhs, rhs ScalarExpression) (Constraint, error) {
30+
func GreaterEq(lhs, rhs interface{}) (Constraint, error) {
3131
return Comparison(lhs, rhs, SenseGreaterThanEqual)
3232
}
3333

@@ -42,31 +42,24 @@ Usage:
4242
constr, err := Comparison(expr1, expr2, SenseGreaterThanEqual)
4343
*/
4444
func Comparison(lhs, rhs interface{}, sense ConstrSense) (Constraint, error) {
45-
// Constants
45+
// Input Processing
46+
var err error
47+
left0, err := ToExpression(lhs)
48+
if err != nil {
49+
return ScalarConstraint{}, fmt.Errorf("lhs is not a valid expression: %v", err)
50+
}
4651

4752
// Algorithm
48-
switch lhs0 := lhs.(type) {
49-
case float64:
50-
// Convert lhs to K
51-
lhsAsK := K(lhs0)
52-
53-
// Create constraint
54-
return Comparison(lhsAsK, rhs, sense)
55-
case mat.VecDense:
56-
// Convert lhs to KVector.
57-
lhsAsKVector := KVector(lhs0)
58-
59-
// Create constraint
60-
return lhsAsKVector.Comparison(rhs, sense)
53+
switch left := left0.(type) {
6154
case ScalarExpression:
6255
rhsAsScalarExpression, _ := rhs.(ScalarExpression)
6356
return ScalarConstraint{
64-
lhs0,
57+
left,
6558
rhsAsScalarExpression,
6659
sense,
6760
}, nil
6861
case VectorExpression:
69-
return lhs0.Comparison(rhs, sense)
62+
return left.Comparison(rhs, sense)
7063
default:
7164
return nil, fmt.Errorf("Comparison in sense '%v' is not defined for lhs type %T and rhs type %T!", sense, lhs, rhs)
7265
}

optim/scalar_constraint.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ type ScalarConstraint struct {
99
Sense ConstrSense
1010
}
1111

12+
func (sc ScalarConstraint) Left() Expression {
13+
return sc.LeftHandSide
14+
}
15+
16+
func (sc ScalarConstraint) Right() Expression {
17+
return sc.RightHandSide
18+
}
19+
1220
// ConstrSense represents if the constraint x <= y, x >= y, or x == y. For easy
1321
// integration with Gurobi, the senses have been encoding using a byte in
1422
// the same way Gurobi encodes the constraint senses.

optim/scalar_expression.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,38 @@ type ScalarExpression interface {
2626

2727
// Plus adds the current expression to another and returns the resulting
2828
// expression
29-
Plus(e interface{}, errors ...error) (ScalarExpression, error)
29+
Plus(rightIn interface{}, errors ...error) (Expression, error)
3030

3131
// Mult multiplies the current expression to another and returns the
3232
// resulting expression
3333
//Mult(c float64, errors ...error) (Expression, error)
3434

3535
// LessEq returns a less than or equal to (<=) constraint between the
3636
// current expression and another
37-
LessEq(rhsIn interface{}, errors ...error) (ScalarConstraint, error)
37+
LessEq(rhsIn interface{}, errors ...error) (Constraint, error)
3838

3939
// GreaterEq returns a greater than or equal to (>=) constraint between the
4040
// current expression and another
41-
GreaterEq(rhsIn interface{}, errors ...error) (ScalarConstraint, error)
41+
GreaterEq(rhsIn interface{}, errors ...error) (Constraint, error)
4242

4343
// Eq returns an equality (==) constraint between the current expression
4444
// and another
45-
Eq(rhsIn interface{}, errors ...error) (ScalarConstraint, error)
45+
Eq(rhsIn interface{}, errors ...error) (Constraint, error)
4646

4747
//Comparison
4848
// Compares the receiver expression rhs with the expression rhs in the sense of sense.
49-
Comparison(rhsIn interface{}, sense ConstrSense, errors ...error) (ScalarConstraint, error)
49+
Comparison(rhsIn interface{}, sense ConstrSense, errors ...error) (Constraint, error)
5050

5151
//Multiply
5252
// Multiplies the given scalar expression with another expression
53-
Multiply(term1 interface{}, errors ...error) (Expression, error)
53+
Multiply(rightIn interface{}, errors ...error) (Expression, error)
5454

5555
//Dims
5656
// Returns the dimensions of the scalar expression (should always be 1,1)
5757
Dims() []int
58+
59+
//Transpose returns the transpose of the given vector expression
60+
Transpose() Expression
5861
}
5962

6063
// NewExpr returns a new expression with a single additive constant value, c,

0 commit comments

Comments
 (0)