From d117ff1511aa24485641a6d0628fc40a538f1291 Mon Sep 17 00:00:00 2001 From: Cecil Date: Sun, 14 Oct 2018 23:06:40 -0600 Subject: [PATCH] for #410, improvements to user layout delation * beqin conversion of cassowary solver from python to ruby --- Tests/layout/l2.rb | 62 +- crontasks/nbuild-shoes-all | 4 +- lib/cassowary/__init__.py | 16 + lib/cassowary/constraint.rb | 238 +++++++ lib/cassowary/edit_info.py | 19 + lib/cassowary/error.py | 17 + lib/cassowary/expression.py | 601 ++++++++++++++++ lib/cassowary/expression.rb | 270 ++++++++ lib/cassowary/simplex_solver.py | 612 ++++++++++++++++ lib/cassowary/simplex_solver.rb | 769 ++++++++++++++++++++ lib/cassowary/tableau.py | 109 +++ lib/cassowary/tableau.rb | 177 +++++ lib/cassowary/test.rb | 31 + lib/cassowary/tests/__init__.py | 0 lib/cassowary/tests/test_constraint.py | 166 +++++ lib/cassowary/tests/test_end_to_end.py | 770 +++++++++++++++++++++ lib/cassowary/tests/test_expression.py | 283 ++++++++ lib/cassowary/tests/test_simplex_solver.py | 69 ++ lib/cassowary/tests/test_tableau.py | 22 + lib/cassowary/tests/test_variable.py | 133 ++++ lib/cassowary/utils.py | 30 + lib/cassowary/utils.rb | 38 + lib/cassowary/variable.rb | 163 +++++ shoes/canvas.c | 13 +- shoes/canvas.h | 1 + shoes/types/layout.c | 70 +- shoes/types/layout.h | 19 +- 27 files changed, 4632 insertions(+), 70 deletions(-) create mode 100644 lib/cassowary/__init__.py create mode 100644 lib/cassowary/constraint.rb create mode 100644 lib/cassowary/edit_info.py create mode 100644 lib/cassowary/error.py create mode 100644 lib/cassowary/expression.py create mode 100644 lib/cassowary/expression.rb create mode 100644 lib/cassowary/simplex_solver.py create mode 100644 lib/cassowary/simplex_solver.rb create mode 100644 lib/cassowary/tableau.py create mode 100644 lib/cassowary/tableau.rb create mode 100644 lib/cassowary/test.rb create mode 100644 lib/cassowary/tests/__init__.py create mode 100644 lib/cassowary/tests/test_constraint.py create mode 100644 lib/cassowary/tests/test_end_to_end.py create mode 100644 lib/cassowary/tests/test_expression.py create mode 100644 lib/cassowary/tests/test_simplex_solver.py create mode 100644 lib/cassowary/tests/test_tableau.py create mode 100644 lib/cassowary/tests/test_variable.py create mode 100644 lib/cassowary/utils.py create mode 100644 lib/cassowary/utils.rb create mode 100644 lib/cassowary/variable.rb diff --git a/Tests/layout/l2.rb b/Tests/layout/l2.rb index 63217f27..c76279c8 100644 --- a/Tests/layout/l2.rb +++ b/Tests/layout/l2.rb @@ -1,7 +1,7 @@ class MyLayout attr_accessor :pos_x, :pos_y, :w, :h - attr_accessor :incr_x, :incr_y + attr_accessor :incr_x, :incr_y, :canvas def initialize() puts "initialized" @@ -9,49 +9,46 @@ def initialize() end def setup(canvas, attr) + @canvas = canvas @w = attr[:width] @h = attr[:height] puts "callback: setup #{@w} X #{@h}" end def add(canvas, widget) - puts "callback add: #{widget.inspect} #{canvas.contents.size}" + puts "callback add: #{widget.class} #{canvas.contents.size}" + puts "w: #{widget.width} h: #{widget.height}" + widget.move @pos_x, @pos_y @pos_x += @incr_x - if @pos_x < 0 - @pos_x = 0 - @incr_x = 25 - end - if @pos_x >= @w - @pos_x = @w - @incr_x = -25 - end @pos_y += @incr_y - if @pos_y <= 0 - @pos_y = 0 - @incr_y = +25 - end - if @pos_y >= @h - @pos_y = @h - 25 - @incr_y = -25 - end - widget.move @pos_x, @pos_y end def clear - @pos_x = -20 - @pos_y = -20 + @pos_x = 5 + @pos_y = 5 @incr_x = 25 @incr_y = 25 puts "callback: clear" end + def refresh + @pos_x = 5 + @pos_y = 5 + @canvas.contents.each do |widget| + widget.move @pos_x, @pos_y + @pos_x += @incr_x + @pos_y += @incr_y + puts "w: #{widget.width} h: #{widget.height}" + end + end + end -Shoes.app width: 350, height: 450, resizeable: true do +Shoes.app width: 380, height: 450, resizeable: true do stack do @p = para "Before layout" @ml = MyLayout.new - @lay =layout manager: @ml, width: 340, height: 380 do + @lay = layout manager: @ml, width: 340, height: 380 do background yellow p1 = para "First Para" a = button "one" @@ -61,17 +58,22 @@ def clear @p.text = @lay.inspect @lay.finish end - button "Append" do - @lay.append { para "appended" } + @el = edit_line width:40 + @el.text = '-1' + button "Insert" do + @lay.insert @el.text.to_i do + para "inserted #{@el.text}" + end + end + button "delete_at" do + @lay.delete_at @el.text.to_i do + para "replaced by deletion" + end end button "Clear" do @lay.clear { background white } end - button "Prepend" do - # problem here? - @lay.prepend { para "prepended" } - end button "refresh" do - @lay.refresh + @ml.refresh end end diff --git a/crontasks/nbuild-shoes-all b/crontasks/nbuild-shoes-all index d25daec7..392351ac 100755 --- a/crontasks/nbuild-shoes-all +++ b/crontasks/nbuild-shoes-all @@ -3,8 +3,8 @@ # my crontab so user name may not be set up for chroot #schroot -c debx86 -u ccoupe -- ~/Projects/shoes3/crontasks/nbuild-shoes-lin64 ~/Projects/shoes3/crontasks/nbuild-shoes-xlin64 -#~/Projects/shoes3/crontasks/nbuild-shoes-xwin7 -~/Projects/shoes3/crontasks/nbuild-shoes-mxe +~/Projects/shoes3/crontasks/nbuild-shoes-xwin7 +#~/Projects/shoes3/crontasks/nbuild-shoes-mxe # ssh to Mac mini and build. #~/Projects/shoes3/crontasks/nbuild-shoes-mavericks ~/Projects/shoes3/crontasks/nbuild-shoes-yosemite diff --git a/lib/cassowary/__init__.py b/lib/cassowary/__init__.py new file mode 100644 index 00000000..d276f37d --- /dev/null +++ b/lib/cassowary/__init__.py @@ -0,0 +1,16 @@ +from __future__ import print_function, unicode_literals, absolute_import + +from .expression import Variable +from .error import RequiredFailure, ConstraintNotFound, InternalError +from .simplex_solver import SimplexSolver +from .utils import REQUIRED, STRONG, MEDIUM, WEAK + +# Examples of valid version strings +# __version__ = '1.2.3.dev1' # Development release 1 +# __version__ = '1.2.3a1' # Alpha Release 1 +# __version__ = '1.2.3b1' # Beta Release 1 +# __version__ = '1.2.3rc1' # RC Release 1 +# __version__ = '1.2.3' # Final Release +# __version__ = '1.2.3.post1' # Post Release 1 + +__version__ = '0.5.1' diff --git a/lib/cassowary/constraint.rb b/lib/cassowary/constraint.rb new file mode 100644 index 00000000..7fe82134 --- /dev/null +++ b/lib/cassowary/constraint.rb @@ -0,0 +1,238 @@ +module Cassowary +########################################################################### +# Constraint +# +# Constraints are the restrictions on linear programming; an equality or +# inequality between two expressions. +########################################################################### +=begin +class AbstractConstraint(object): + def __init__(self, strength, weight=1.0): + self.strength = strength + self.weight = weight + self.is_edit_constraint = False + self.is_inequality = False + self.is_stay_constraint = False + + @property + def is_required(self): + return self.strength == REQUIRED + + def __repr__(self): + return '%s:{%s}(%s)' % (repr_strength(self.strength), self.weight, self.expression) +=end + class AbstractConstraint + attr_accessor :strength, :weight, :s_edit_constraint, :is_inequality, + :is_stay_contraint + + def initialize(strength, weight=1.0) + @strength = strength + @weight = weight + @is_edit_constraint = false + @is_inequality = false + @is_stay_contraint = false + end + + def is_required + @strength == REQUIRED + end + end + + class EditConstraint < AbstractConstraint + attr_accessor :variable, :expression + + def initialize(variable, strength=STRONG, weight=1.0) + super(strength, weight) + @variable = variable + @expression = Expression.new(variable, -1.0, variable.value) + @is_edit_constraint = true + end + end + + class StayConstraint < AbstractConstraint + attr_accessor :variable, :expression + + def initialize(variable, strength=STRONG, weight=1.0) + super(strength, weight) + @variable = variable + @expression = Expression.new(variable, -1.0, variable.value) + @is_stay_constraint = true + end + end + +=begin +class EditConstraint(AbstractConstraint): + def __init__(self, variable, strength=STRONG, weight=1.0): + super(EditConstraint, self).__init__(strength, weight) + self.variable = variable + self.expression = Expression(variable, -1.0, variable.value) + self.is_edit_constraint = True + + def __repr__(self): + return 'edit:%s' % super(EditConstraint, self).__repr__() + + +class StayConstraint(AbstractConstraint): + def __init__(self, variable, strength=STRONG, weight=1.0): + super(StayConstraint, self).__init__(strength, weight) + self.variable = variable + self.expression = Expression(variable, -1.0, variable.value) + self.is_stay_constraint=True + + def __repr__(self): + return 'stay:%s' % super(StayConstraint, self).__repr__() + +=end + + class Constraint < AbstractConstraint + LEQ = -1 + EQ = 0 + GEQ = 1 + + def initialize(param1, operator=EQ, param2=None, strength=REQUIRED, weight=1.0) + # Define a new linear constraint. + # + # param1 may be an expression or variable + # param2 may be an expression, variable, or constant, or may be ommitted entirely. + # If param2 is specified, the operator must be either LEQ, EQ, or GEQ + + if param1.kind_of? Expression + if param2 == nil + super(strength=strength, weight=weight) + @expression = param1 + elsif param2.kind_of? Expression + super(strength=strength, weight=weight) + @expression = param1.clone() + if operator == self.LEQ + @expression.multiply(-1.0) + @expression.add_expression(param2, 1.0) + elsif operator == self.EQ + @expression.add_expression(param2, -1.0) + elsif operator == self.GEQ + @expression.add_expression(param2, -1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Variable + super(strength=strength, weight=weight) + @expression = param1.clone() + if operator == self.LEQ + @expression.multiply(-1.0) + @expression.add_variable(param2, 1.0) + elsif operator == self.EQ + @expression.add_variable(param2, -1.0) + elsif operator == self.GEQ + @expression.add_variable(param2, -1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Numeric + super(strength=strength, weight=weight) + @expression = param1.clone() + if operator == self.LEQ + @expression.multiply(-1.0) + @expression.add_expression(Expression(constant=param2), 1.0) + elsif operator == self.EQ + @expression.add_expression(Expression(constant=param2), -1.0) + elsif operator == self.GEQ + @expression.add_expression(Expression(constant=param2), -1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + else + raise InternalError("Invalid parameters to Constraint constructor") + end + elsif param1.kind_of? Variable + if param2 == nil + super(strength=strength, weight=weight) + @expression = Expression.new(param1) + elsif param2.kind_of? Expression + super(strength=strength, weight=weight) + @expression = param2.clone() + if operator == self.LEQ + @expression.add_variable(param1, -1.0) + elsif operator == self.EQ + @expression.add_variable(param1, -1.0) + elsif operator == self.GEQ + @expression.multiply(-1.0) + @expression.add_variable(param1, 1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Variable + super(strength=strength, weight=weight) + @expression = Expression.new(param2) + if operator == self.LEQ + @expression.add_variable(param1, -1.0) + elsif operator == self.EQ + @expression.add_variable(param1, -1.0) + elsif operator == self.GEQ + @expression.multiply(-1.0) + @expression.add_variable(param1, 1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Numeric + super(strength=strength, weight=weight) + @expression = Expression.new(constant=param2) + if operator == self.LEQ + @expression.add_variable(param1, -1.0) + elsif operator == self.EQ + @expression.add_variable(param1, -1.0) + elsif operator == self.GEQ + @expression.multiply(-1.0) + @expression.add_variable(param1, 1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + else + raise InternalError("Invalid parameters to Constraint constructor") + end + elsif param1.kind_of? Numeric + if param2 == nil + super(strength=strength, weight=weight) + @expression = Expression(constant=param1) + elsif param2.kind_of? Expression + super(strength=strength, weight=weight) + @expression = param2.clone() + if operator == self.LEQ + @expression.add_expression(Expression.new(constant=param1), -1.0) + elsif operator == self.EQ + @expression.add_expression(Expression.new(constant=param1), -1.0) + elsif operator == self.GEQ + @expression.multiply(-1.0) + @expression.add_expression(Expression.new(constant=param1), 1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Variable + super(strength=strength, weight=weight) + @expression = Expression,new(constant=param1) + if operator == self.LEQ + @expression.add_variable(param2, -1.0) + elsif operator == self.EQ + @expression.add_variable(param2, -1.0) + elsif operator == self.GEQ + @expression.multiply(-1.0) + @expression.add_variable(param2, 1.0) + else + raise InternalError("Invalid operator in Constraint constructor") + end + elsif param2.kind_of? Numeric + raise InternalError("Cannot create an inequality between constants") + else + raise InternalError("Invalid parameters to Constraint constructor") + end + else + raise InternalError("Invalid parameters to Constraint constructor") + end + @is_inequality = operator != self.EQ + end + + def clone() + c = Constraint(@expression, strength=@strength, weight=@weight) + c.is_inequality = @is_inequality + return c + end + end # class Constraint +end # module Cassowary diff --git a/lib/cassowary/edit_info.py b/lib/cassowary/edit_info.py new file mode 100644 index 00000000..85e3802b --- /dev/null +++ b/lib/cassowary/edit_info.py @@ -0,0 +1,19 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + + +class EditInfo(object): + def __init__(self, constraint, edit_plus, edit_minus, prev_edit_constant, index): + self.constraint = constraint + self.edit_plus = edit_plus + self.edit_minus = edit_minus + self.prev_edit_constant = prev_edit_constant + self.index = index + + def __repr__(self): + return '' % ( + self.constraint, + self.edit_plus, + self.edit_minus, + self.prev_edit_constant, + self.index + ) diff --git a/lib/cassowary/error.py b/lib/cassowary/error.py new file mode 100644 index 00000000..34e0e9a8 --- /dev/null +++ b/lib/cassowary/error.py @@ -0,0 +1,17 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + + +class CassowaryException(Exception): + pass + + +class InternalError(CassowaryException): + pass + + +class ConstraintNotFound(CassowaryException): + pass + + +class RequiredFailure(CassowaryException): + pass diff --git a/lib/cassowary/expression.py b/lib/cassowary/expression.py new file mode 100644 index 00000000..5ead821a --- /dev/null +++ b/lib/cassowary/expression.py @@ -0,0 +1,601 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from .error import InternalError +from .utils import approx_equal, REQUIRED, STRONG, repr_strength + +########################################################################### +# Variables +# +# Variables are the atomic unit of linear programming, describing the +# quantities that are to be solved and constrained. +########################################################################### + +class AbstractVariable(object): + def __init__(self, name): + self.name = name + self.is_dummy = False + self.is_external = False + self.is_pivotable = False + self.is_restricted = False + + def __rmul__(self, x): + return self.__mul__(x) + + def __mul__(self, x): + if isinstance(x, (float, int)): + return Expression(self, x) + elif isinstance(x, Expression): + if x.is_constant: + return Expression(self, value=x.constant) + else: + return NotImplemented + else: + return NotImplemented + + def __truediv__(self, x): + return self.__div__(x) + + def __div__(self, x): + if isinstance(x, (float, int)): + if approx_equal(x, 0): + raise ZeroDivisionError() + return Expression(self, 1.0 / x) + elif isinstance(x, Expression): + if x.is_constant: + return Expression(self, value=1.0/x.constant) + else: + return NotImplemented + else: + return NotImplemented + + def __radd__(self, x): + return self.__add__(x) + + def __add__(self, x): + if isinstance(x, (int, float)): + return Expression(self, constant=x) + elif isinstance(x, Expression): + return Expression(self) + x + elif isinstance(x, AbstractVariable): + return Expression(self) + Expression(x) + else: + return NotImplemented + + def __rsub__(self, x): + if isinstance(x, (int, float)): + return Expression(self, -1.0, constant=x) + elif isinstance(x, Expression): + return x - Expression(self) + elif isinstance(x, AbstractVariable): + return Expression(x) - Expression(self) + else: + return NotImplemented + + def __sub__(self, x): + if isinstance(x, (int, float)): + return Expression(self, constant=-x) + elif isinstance(x, Expression): + return Expression(self) - x + elif isinstance(x, AbstractVariable): + return Expression(self) - Expression(x) + else: + return NotImplemented + + +class Variable(AbstractVariable): + def __init__(self, name, value=0.0): + super(Variable, self).__init__(name) + self.value = float(value) + self.is_external = True + + def __repr__(self): + return '%s[%s]' % (self.name, self.value) + + __hash__ = object.__hash__ + + def __eq__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.EQ, other) + else: + return NotImplemented + + def __lt__(self, other): + # < and <= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self.__le__(other) + + def __le__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.LEQ, other) + else: + return NotImplemented + + def __gt__(self, other): + # > and >= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self.__ge__(other) + + def __ge__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.GEQ, other) + else: + return NotImplemented + + +class DummyVariable(AbstractVariable): + def __init__(self, number): + super(DummyVariable, self).__init__(name='d%s' % (number)) + self.is_dummy = True + self.is_restricted = True + + def __repr__(self): + return '%s:dummy' % self.name + + +class ObjectiveVariable(AbstractVariable): + def __init__(self, name): + super(ObjectiveVariable, self).__init__(name) + + def __repr__(self): + return '%s:obj' % self.name + + +class SlackVariable(AbstractVariable): + def __init__(self, prefix, number): + super(SlackVariable, self).__init__(name='%s%s' % (prefix, number)) + self.is_pivotable = True + self.is_restricted = True + + def __repr__(self): + return '%s:slack' % self.name + +########################################################################### +# Expressions +# +# Expressions are combinations of variables with multipliers and constants +########################################################################### + + +class Expression(object): + def __init__(self, variable=None, value=1.0, constant=0.0): + assert isinstance(constant, (float, int)) + assert variable is None or isinstance(variable, AbstractVariable) + + self.constant = float(constant) + self.terms = {} + + if variable: + self.set_variable(variable, float(value)) + + def __repr__(self): + parts = [] + if not approx_equal(self.constant, 0.0) or self.is_constant: + parts.append(repr(self.constant)) + for clv, coeff in sorted(self.terms.items(), key=lambda x:repr(x)): + if approx_equal(coeff, 1.0): + parts.append(repr(clv)) + else: + parts.append(repr(coeff) + "*" + repr(clv)) + return ' + '.join(parts) + + @property + def is_constant(self): + return not self.terms + + def clone(self): + expr = Expression(constant=self.constant) + for clv, value in self.terms.items(): + expr.set_variable(clv, value) + return expr + + ###################################################################### + # Mathematical operators + ###################################################################### + + def __rmul__(self, x): + return self.__mul__(x) + + def __mul__(self, x): + if isinstance(x, Expression): + if self.is_constant: + result = x * self.constant + elif x.is_constant: + result = self * x.constant + else: + return NotImplemented + elif isinstance(x, Variable): + if self.is_constant: + result = Expression(x, self.constant) + else: + return NotImplemented + elif isinstance(x, (float, int)): + result = Expression(constant=self.constant * x) + for clv, value in self.terms.items(): + result.set_variable(clv, value * x) + else: + return NotImplemented + return result + + def __truediv__(self, x): + return self.__div__(x) + + def __div__(self, x): + if isinstance(x, (float, int)): + if approx_equal(x, 0): + raise ZeroDivisionError() + result = Expression(constant=self.constant / x) + for clv, value in self.terms.items(): + result.set_variable(clv, value / x) + else: + if x.is_constant: + result = self / x.constant + else: + return NotImplemented + return result + + def __radd__(self, x): + return self.__add__(x) + + def __add__(self, x): + if isinstance(x, Expression): + result = self.clone() + result.add_expression(x, 1.0) + return result + elif isinstance(x, Variable): + result = self.clone() + result.add_variable(x, 1.0) + return result + elif isinstance(x, (int, float)): + result = self.clone() + result.add_expression(Expression(constant=x), 1.0) + return result + else: + return NotImplemented + + def __rsub__(self, x): + if isinstance(x, Expression): + result = self.clone() + result.multiply(-1.0) + result.add_expression(x, 1.0) + return result + elif isinstance(x, Variable): + result = self.clone() + result.multiply(-1.0) + result.add_variable(x, 1.0) + return result + elif isinstance(x, (int, float)): + result = self.clone() + result.multiply(-1.0) + result.add_expression(Expression(constant=x), 1.0) + return result + else: + return NotImplemented + + def __sub__(self, x): + if isinstance(x, Expression): + result = self.clone() + result.add_expression(x, -1.0) + return result + elif isinstance(x, Variable): + result = self.clone() + result.add_variable(x, -1.0) + return result + elif isinstance(x, (int, float)): + result = self.clone() + result.add_expression(Expression(constant=x), -1.0) + return result + else: + return NotImplemented + + ###################################################################### + # Mathematical operators + ###################################################################### + + __hash__ = object.__hash__ + + def __eq__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.EQ, other) + else: + return NotImplemented + + def __lt__(self, other): + # < and <= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self.__le__(other) + + def __le__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.LEQ, other) + else: + return NotImplemented + + def __gt__(self, other): + # > and >= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self.__ge__(other) + + def __ge__(self, other): + if isinstance(other, (Expression, Variable, float, int)): + return Constraint(self, Constraint.GEQ, other) + else: + return NotImplemented + + ###################################################################### + # Internal mechanisms + ###################################################################### + + def add_expression(self, expr, n=1.0, subject=None, solver=None): + if isinstance(expr, AbstractVariable): + expr = Expression(variable=expr) + + self.constant = self.constant + n * expr.constant + for clv, coeff in expr.terms.items(): + self.add_variable(clv, coeff * n, subject, solver) + + def add_variable(self, v, cd=1.0, subject=None, solver=None): + # print 'expression: add_variable', v, cd + coeff = self.terms.get(v) + if coeff: + new_coefficient = coeff + cd + if approx_equal(new_coefficient, 0.0): + if solver: + solver.note_removed_variable(v, subject) + self.remove_variable(v) + else: + self.set_variable(v, new_coefficient) + else: + if not approx_equal(cd, 0.0): + self.set_variable(v, cd) + if solver: + solver.note_added_variable(v, subject) + + def set_variable(self, v, c): + self.terms[v] = float(c) + + def remove_variable(self, v): + del self.terms[v] + + def any_pivotable_variable(self): + if self.is_constant: + raise InternalError('any_pivotable_variable called on a constant') + + retval = None + for clv, c in self.terms.items(): + if clv.is_pivotable: + retval = clv + break + + return retval + + def substitute_out(self, outvar, expr, subject=None, solver=None): + multiplier = self.terms.pop(outvar) + self.constant = self.constant + multiplier * expr.constant + + for clv, coeff in expr.terms.items(): + old_coefficient = self.terms.get(clv) + if old_coefficient: + new_coefficient = old_coefficient + multiplier * coeff + if approx_equal(new_coefficient, 0): + solver.note_removed_variable(clv, subject) + del self.terms[clv] + else: + self.set_variable(clv, new_coefficient) + else: + self.set_variable(clv, multiplier * coeff) + if solver: + solver.note_added_variable(clv, subject) + + def change_subject(self, old_subject, new_subject): + self.set_variable(old_subject, self.new_subject(new_subject)) + + def multiply(self, x): + self.constant = self.constant * float(x) + for clv, value in self.terms.items(): + self.set_variable(clv, value * x) + + def new_subject(self, subject): + # print "new_subject", subject + value = self.terms.pop(subject) + reciprocal = 1.0 / value + self.multiply(-reciprocal) + return reciprocal + + def coefficient_for(self, clv): + return self.terms.get(clv, 0.0) + + +########################################################################### +# Constraint +# +# Constraints are the restrictions on linear programming; an equality or +# inequality between two expressions. +########################################################################### + +class AbstractConstraint(object): + def __init__(self, strength, weight=1.0): + self.strength = strength + self.weight = weight + self.is_edit_constraint = False + self.is_inequality = False + self.is_stay_constraint = False + + @property + def is_required(self): + return self.strength == REQUIRED + + def __repr__(self): + return '%s:{%s}(%s)' % (repr_strength(self.strength), self.weight, self.expression) + +class EditConstraint(AbstractConstraint): + def __init__(self, variable, strength=STRONG, weight=1.0): + super(EditConstraint, self).__init__(strength, weight) + self.variable = variable + self.expression = Expression(variable, -1.0, variable.value) + self.is_edit_constraint = True + + def __repr__(self): + return 'edit:%s' % super(EditConstraint, self).__repr__() + + +class StayConstraint(AbstractConstraint): + def __init__(self, variable, strength=STRONG, weight=1.0): + super(StayConstraint, self).__init__(strength, weight) + self.variable = variable + self.expression = Expression(variable, -1.0, variable.value) + self.is_stay_constraint=True + + def __repr__(self): + return 'stay:%s' % super(StayConstraint, self).__repr__() + + +class Constraint(AbstractConstraint): + LEQ = -1 + EQ = 0 + GEQ = 1 + + def __init__(self, param1, operator=EQ, param2=None, strength=REQUIRED, weight=1.0): + """Define a new linear constraint. + + param1 may be an expression or variable + param2 may be an expression, variable, or constant, or may be ommitted entirely. + If param2 is specified, the operator must be either LEQ, EQ, or GEQ + """ + if isinstance(param1, Expression): + if param2 is None: + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param1 + elif isinstance(param2, Expression): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param1.clone() + if operator == self.LEQ: + self.expression.multiply(-1.0) + self.expression.add_expression(param2, 1.0) + elif operator == self.EQ: + self.expression.add_expression(param2, -1.0) + elif operator == self.GEQ: + self.expression.add_expression(param2, -1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + elif isinstance(param2, Variable): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param1.clone() + if operator == self.LEQ: + self.expression.multiply(-1.0) + self.expression.add_variable(param2, 1.0) + elif operator == self.EQ: + self.expression.add_variable(param2, -1.0) + elif operator == self.GEQ: + self.expression.add_variable(param2, -1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + + elif isinstance(param2, (float, int)): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param1.clone() + if operator == self.LEQ: + self.expression.multiply(-1.0) + self.expression.add_expression(Expression(constant=param2), 1.0) + elif operator == self.EQ: + self.expression.add_expression(Expression(constant=param2), -1.0) + elif operator == self.GEQ: + self.expression.add_expression(Expression(constant=param2), -1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + else: + raise InternalError("Invalid parameters to Constraint constructor") + + elif isinstance(param1, Variable): + if param2 is None: + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = Expression(param1) + elif isinstance(param2, Expression): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param2.clone() + if operator == self.LEQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.EQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.GEQ: + self.expression.multiply(-1.0) + self.expression.add_variable(param1, 1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + + elif isinstance(param2, Variable): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = Expression(param2) + if operator == self.LEQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.EQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.GEQ: + self.expression.multiply(-1.0) + self.expression.add_variable(param1, 1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + + elif isinstance(param2, (float, int)): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = Expression(constant=param2) + if operator == self.LEQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.EQ: + self.expression.add_variable(param1, -1.0) + elif operator == self.GEQ: + self.expression.multiply(-1.0) + self.expression.add_variable(param1, 1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + else: + raise InternalError("Invalid parameters to Constraint constructor") + + elif isinstance(param1, (float, int)): + if param2 is None: + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = Expression(constant=param1) + + elif isinstance(param2, Expression): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = param2.clone() + if operator == self.LEQ: + self.expression.add_expression(Expression(constant=param1), -1.0) + elif operator == self.EQ: + self.expression.add_expression(Expression(constant=param1), -1.0) + elif operator == self.GEQ: + self.expression.multiply(-1.0) + self.expression.add_expression(Expression(constant=param1), 1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + + elif isinstance(param2, Variable): + super(Constraint, self).__init__(strength=strength, weight=weight) + self.expression = Expression(constant=param1) + if operator == self.LEQ: + self.expression.add_variable(param2, -1.0) + elif operator == self.EQ: + self.expression.add_variable(param2, -1.0) + elif operator == self.GEQ: + self.expression.multiply(-1.0) + self.expression.add_variable(param2, 1.0) + else: + raise InternalError("Invalid operator in Constraint constructor") + + elif isinstance(param2, (float, int)): + raise InternalError("Cannot create an inequality between constants") + + else: + raise InternalError("Invalid parameters to Constraint constructor") + else: + raise InternalError("Invalid parameters to Constraint constructor") + + self.is_inequality = operator != self.EQ + + def clone(self): + c = Constraint(self.expression, strength=self.strength, weight=self.weight) + c.is_inequality = self.is_inequality + return c diff --git a/lib/cassowary/expression.rb b/lib/cassowary/expression.rb new file mode 100644 index 00000000..fcbd6811 --- /dev/null +++ b/lib/cassowary/expression.rb @@ -0,0 +1,270 @@ +########################################################################### +# Expressions +# +# Expressions are combinations of variables with multipliers and constants +########################################################################### +module Cassowary + + class Expression + attr_accessor :constant, :terms + + def initialize(variable=nil, value=1.0, constant=0.0) + @constant = constant + @terms = {} + if variable + self.set_variable(variable,value) + end + end + + def set_variable(v, c) + @terms[v] = c.to_f + end + + def remove_variable(v) + @terms.delete(v) + end + + def is_constant + return ! @terms + end + + # Ruby Object.clone _should_ work for Expression.clone() + + + + ###################################################################### + # Mathematical operators + ###################################################################### + def *(x) + result = 0.0 + if x.kind_of? Expression + if self.is_constant + result = x * @constant + elsif x.is_constant + result = self. * x.constant # ? + else + return NotImplemented + end + elsif x.kind_of? Variable + if self.is_constant + result = Expression.new(x, @constant) + else + return NotImplemented + end + elsif x.kind_of? Numeric + result = Expression.new(constant=@constant * x) + @terms.each_pair do |clv, value| + result.set_variable(clv, value * x) + end + else + return NotImplemented + end + return result + end + + def /(x) + result = 0.0 + if x.kind_of? Numeric + if approx_equal(x, 0) + raise ZeroDivisionError + end + result = Expression.new(constant=@constant / x) + terms.each_pair do |clv, value| + result.set_variable(clv, value / x) + end + else + if x.is_constant + result = self / x.constant # don't like 'self' here + else + return NotImplemented + end + end + return result + end + + def +(x) + if x.kind_of? Expression + result = self.clone + result.add_expression(x, 1.0) + return result + elsif x.kind_of? Variable + result = self.clone + result.add_variable(x, 1.0) + return result + elsif x.kind_of? Numeric + result = self.clone + result.add_expression(Expression.new(constant=x), 1.0) + return result + else + return NotImplemented + end + end + + def -(x) + case x.kind_of? + when Expression + result = self.clone + result.add_expression(x, -1.0) + return result + when Variable + result = self.clone + result.add_variable(x, -1.0) + return result + when Numeric + result = self.clone + result.add_expression(Expression.new(constant = x), -1.0) + return result + else + return NotImplemented + end + end + + ###################################################################### + # Mathematical operators + ###################################################################### + def ==(other) # TODO: can't use = in Ruby + case other.kind_of? + when Expression,Variable,Numeric + return Constraint(self, Constraint.EQ, other) + else + return NotImplemented + end + end + + def <(other) + # < and <= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + case other.kind_of? + when Expression,Variable,Numeric + return Constraint(self, Constraint.LEQ, other) + else + return NotImplemented + end + end + + def <=(other) + case other.kind_of? + when Expression,Variable,Numeric + return Constraint(self, Constraint.LEQ, other) + else + return NotImplemented + end + end + + def >(other) + # < and <= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + case other.kind_of? + when Expression,Variable,Numeric + return Constraint(self, Constraint.GEQ, other) + else + return NotImplemented + end + end + + def >=(other) + case other.kind_of? + when Expression,Variable,Numeric + return Constraint(self, Constraint.GEQ, other) + else + return NotImplemented + end + end + + ###################################################################### + # Internal mechanisms + ###################################################################### + def add_expression(expr, n=1.0, subject=nil, solver=nil) + if expr.kind_of? AbstractVariable + expr = Expression.new(variable=expr) + end + @constant = @constant + n * expr.constant + @terms.each_pair do |clv, coeff| + self.add_variable(clv, coeff * n, subject, solver) + end + end + + def add_variable(v, cd=1.0, subject=nil, solver=nil) + coeff = @terms[v] + if coeff + new_coefficient = coeff + cd + if approx_equal(new_coefficient, 0.0) + if solver + solver.note_removed_variable(v, subject) + end + self.remove_variable(v) + else + self.set_variable(v, new_coefficient) + end + else + if ! approx_equal(cd, 0.0) + self.set_variable(v, cd) + if solver + solver.note_added_variable(v, subject) + end + end + end + end + + def any_pivotable_variable + if self.is_constant + raise InternalError('any_pivotable_variable called on a constant') + end + retval = nil + @terms.each_pair do |clv, c| + if clv.is_pivotable + retval = clv + break + end + end + return retval + end + + def substitute_out(outvar, expr, subject=nil, solver=nil) + multiplier = @terms.delete(outvar) + @constant = @constant + multiplier * expr.constant + expr.terms.each_pair do |clv, coeff| + old_coefficient = @terms[clv] + if old_coefficient + new_coefficient = old_coefficient + multiplier * coeff + if approx_equal(new_coefficient, 0) + solver.note_removed_variable(clv, subject) + @terms.delete(cv) + else + self.set_variable(clv,new_coefficient) + end + else + self.set_variable(clv, multiplier * coeff) + if solver + solver.note_added_variable(clv, subject) + end + end + end + end + + def change_subject(old_subject, new_subject) + self.set_variable(old_subject, self.new_subject(new_subject)) + end + + def multiply(x) + @constant = @constant * x.to_f + @terms.each_pair do |clv, value| + self.set_variable(clv, value * x) + end + end + + def new_subject(subject) + value = @terms.delete(subject) + reciprocal = 1.0 / value + self.multiply(-reciprocal) + return reciprocal + end + + def coefficient_for(clv) + return terms[clv] || 0.0 + end + end # class Expression + +end # Cassowary module diff --git a/lib/cassowary/simplex_solver.py b/lib/cassowary/simplex_solver.py new file mode 100644 index 00000000..2ab10317 --- /dev/null +++ b/lib/cassowary/simplex_solver.py @@ -0,0 +1,612 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from .edit_info import EditInfo +from .error import RequiredFailure, ConstraintNotFound, InternalError +from .expression import Expression, StayConstraint, EditConstraint, ObjectiveVariable, SlackVariable, DummyVariable +from .tableau import Tableau +from .utils import approx_equal, EPSILON, STRONG, WEAK + + +class SolverEditContext(object): + def __init__(self, solver): + self.solver = solver + + def __enter__(self): + self.solver.begin_edit() + + def __exit__(self, type, value, tb): + self.solver.end_edit() + + +class SimplexSolver(Tableau): + def __init__(self): + super(SimplexSolver, self).__init__() + + self.stay_error_vars = [] + + self.error_vars = {} + self.marker_vars = {} + + self.objective = ObjectiveVariable('Z') + self.edit_var_map = {} + + self.slack_counter = 0 + self.artificial_counter = 0 + self.dummy_counter = 0 + self.auto_solve = True + self.needs_solving = False + + self.optimize_count = 0 + + self.rows[self.objective] = Expression() + self.edit_variable_stack = [0] + + def __repr__(self): + parts = [] + parts.append('stay_error_vars: %s' % self.stay_error_vars) + parts.append('edit_var_map: %s' % self.edit_var_map) + return super(SimplexSolver, self).__repr__() + '\n' + '\n'.join(parts) + + + def add_constraint(self, cn, strength=None, weight=None): + if strength or weight: + cn = cn.clone() + if strength: + cn.strength = strength + if weight: + cn.weight = weight + + # print('add_constraint', cn) + expr, eplus, eminus, prev_edit_constant = self.new_expression(cn) + + if not self.try_adding_directly(expr): + self.add_with_artificial_variable(expr) + + self.needs_solving = True + + if cn.is_edit_constraint: + i = len(self.edit_var_map) + + self.edit_var_map[cn.variable] = EditInfo(cn, eplus, eminus, prev_edit_constant, i) + + if self.auto_solve: + self.optimize(self.objective) + self.set_external_variables() + + return cn + + def add_edit_var(self, v, strength=STRONG): + # print("add_edit_var", v, strength) + return self.add_constraint(EditConstraint(v, strength)) + + def remove_edit_var(self, v): + self.remove_constraint(self.edit_var_map[v].constraint) + + def edit(self): + return SolverEditContext(self) + + def resolve(self): + self.dual_optimize() + self.set_external_variables() + self.infeasible_rows.clear() + self.reset_stay_constants() + + ####################################################################### + # Internals + ####################################################################### + + def new_expression(self, cn): + # print("* new_expression", cn) + # print("cn.is_inequality == ", cn.is_inequality) + # print("cn.is_required == ", cn.is_required) + expr = Expression(constant=cn.expression.constant) + eplus = None + eminus = None + prev_edit_constant = None + for v, c in cn.expression.terms.items(): + e = self.rows.get(v) + if not e: + expr.add_variable(v, c) + else: + expr.add_expression(e, c) + + if cn.is_inequality: + # print("Inequality, adding slack") + self.slack_counter = self.slack_counter + 1 + slack_var = SlackVariable(prefix='s', number=self.slack_counter) + expr.set_variable(slack_var, -1) + + self.marker_vars[cn] = slack_var + if not cn.is_required: + self.slack_counter = self.slack_counter + 1 + eminus = SlackVariable(prefix='em', number=self.slack_counter) + expr.set_variable(eminus, 1) + z_row = self.rows[self.objective] + z_row.set_variable(eminus, cn.strength * cn.weight) + self.insert_error_var(cn, eminus) + self.note_added_variable(eminus, self.objective) + else: + if cn.is_required: + # print("Equality, required") + self.dummy_counter = self.dummy_counter + 1 + dummy_var = DummyVariable(number=self.dummy_counter) + eplus = dummy_var + eminus = dummy_var + prev_edit_constant = cn.expression.constant + expr.set_variable(dummy_var, 1) + self.marker_vars[cn] = dummy_var + # print("Adding dummy_var == d%s" % self.dummy_counter) + else: + # print("Equality, not required") + self.slack_counter = self.slack_counter + 1 + eplus = SlackVariable(prefix='ep', number=self.slack_counter) + eminus = SlackVariable(prefix='em', number=self.slack_counter) + expr.set_variable(eplus, -1) + expr.set_variable(eminus, 1) + self.marker_vars[cn] = eplus + + z_row = self.rows[self.objective] + # print("z_row", z_row) + sw_coeff = cn.strength * cn.weight + # if sw_coeff == 0: + # print("cn ==", cn) + # print("adding ", eplus, "and", eminus, "with sw_coeff", sw_coeff) + z_row.set_variable(eplus, sw_coeff) + self.note_added_variable(eplus, self.objective) + z_row.set_variable(eminus, sw_coeff) + self.note_added_variable(eminus, self.objective) + + self.insert_error_var(cn, eminus) + self.insert_error_var(cn, eplus) + + if cn.is_stay_constraint: + self.stay_error_vars.append((eplus, eminus)) + elif cn.is_edit_constraint: + prev_edit_constant = cn.expression.constant + + # print('new_expression returning:', expr) + if expr.constant < 0: + expr.multiply(-1.0) + return expr, eplus, eminus, prev_edit_constant + + def begin_edit(self): + assert len(self.edit_var_map) > 0 + self.infeasible_rows.clear() + self.reset_stay_constants() + self.edit_variable_stack.append(len(self.edit_var_map)) + + def end_edit(self): + assert len(self.edit_var_map) > 0 + self.resolve() + self.edit_variable_stack.pop() + self.remove_edit_vars_to(self.edit_variable_stack[-1]) + + def remove_all_edit_vars(self): + self.remove_edit_vars_to(0) + + def remove_edit_vars_to(self, n): + try: + removals = [] + for v, cei in self.edit_var_map.items(): + if cei.index >= n: + removals.append(v) + + for v in removals: + self.remove_edit_var(v) + + assert len(self.edit_var_map) == n + + except ConstraintNotFound: + raise InternalError('Constraint not found during internal removal') + + def add_stay(self, v, strength=WEAK, weight=1.0): + return self.add_constraint(StayConstraint(v, strength, weight)) + + def remove_constraint(self, cn): + # print("removeConstraint", cn) + # print(self) + self.needs_solving = True + self.reset_stay_constants() + z_row = self.rows[self.objective] + + e_vars = self.error_vars.get(cn) + # print("e_vars ==", e_vars) + if e_vars: + for cv in e_vars: + try: + z_row.add_expression(self.rows[cv], -cn.weight * cn.strength, self.objective, self) + # print('add expression', self.rows[cv]) + except KeyError: + z_row.add_variable(cv, -cn.weight * cn.strength, self.objective, self) + # print('add variable', cv) + + try: + marker = self.marker_vars.pop(cn) + except KeyError: + raise ConstraintNotFound() + + # print("Looking to remove var", marker) + if not self.rows.get(marker): + col = self.columns[marker] + # print("Must pivot -- columns are", col) + exit_var = None + min_ratio = 0.0 + for v in col: + # print('check var', v) + if v.is_restricted: + # print('var', v, ' is restricted') + expr = self.rows[v] + coeff = expr.coefficient_for(marker) + # print("Marker", marker, "'s coefficient in", expr, "is", coeff) + if coeff < 0: + r = -expr.constant / coeff + if exit_var is None or r < min_ratio: # EXTRA BITS IN JS? + # print('set exit var = ',v,r) + min_ratio = r + exit_var = v + + if exit_var is None: + # print("exit_var is still None") + for v in col: + # print('check var', v) + if v.is_restricted: + # print('var', v, ' is restricted') + expr = self.rows[v] + coeff = expr.coefficient_for(marker) + # print("Marker", marker, "'s coefficient in", expr, "is", coeff) + r = expr.constant / coeff + if exit_var is None or r < min_ratio: + # print('set exit var = ',v,r) + min_ratio = r + exit_var = v + + if exit_var is None: + # print("exit_var is still None (again)") + if len(col) == 0: + # print('remove column',marker) + self.remove_column(marker) + else: + exit_var = [v for v in col if v != self.objective][-1] # ?? + # print('set exit var', exit_var) + + if exit_var is not None: + # print('Pivot', marker, exit_var,) + self.pivot(marker, exit_var) + + if self.rows.get(marker): + # print('remove row', marker) + expr = self.remove_row(marker) + + if e_vars: + # print('e_vars exist') + for v in e_vars: + if v != marker: + # print('remove column',v) + self.remove_column(v) + + if cn.is_stay_constraint: + if e_vars: + # for p_evar, m_evar in self.stay_error_vars: + remaining = [] + while self.stay_error_vars: + p_evar, m_evar = self.stay_error_vars.pop() + found = False + try: + # print('stay constraint - remove plus evar', p_evar) + e_vars.remove(p_evar) + found = True + except KeyError: + pass + try: + # print('stay constraint - remove minus evar', m_evar) + e_vars.remove(m_evar) + found = True + except KeyError: + pass + if not found: + remaining.append((p_evar, m_evar)) + self.stay_error_vars = remaining + + elif cn.is_edit_constraint: + assert e_vars is not None + # print('edit constraint - remove column', self.edit_var_map[cn.variable].edit_minus) + self.remove_column(self.edit_var_map[cn.variable].edit_minus) + del self.edit_var_map[cn.variable] + + if e_vars: + for e_var in e_vars: + # print('Remove error var', e_var) + del self.error_vars[e_var] + + if self.auto_solve: + # print('final auto solve') + self.optimize(self.objective) + self.set_external_variables() + + def resolve_array(self, new_edit_constants): + for v, cei in self.edit_var_map.items(): + self.suggest_value(v, new_edit_constants[cei.index]) + + self.resolve() + + def suggest_value(self, v, x): + cei = self.edit_var_map.get(v) + if not cei: + raise InternalError("suggestValue for variable %s, but var is not an edit variable" % v) + # print(cei) + delta = x - cei.prev_edit_constant + cei.prev_edit_constant = x + self.delta_edit_constant(delta, cei.edit_plus, cei.edit_minus) + + def solve(self): + if self.needs_solving: + self.optimize(self.objective) + self.set_external_variables() + + def set_edited_value(self, v, n): + if v not in self.columns or v not in self.rows: + v.value = n + + if not approx_equal(n, v.value): + self.add_edit_var(v) + self.begin_edit() + + self.suggest_value(v, n) + + self.end_edit() + + def add_var(self, v): + if v not in self.columns or v not in self.rows: + self.add_stay(v) + + def add_with_artificial_variable(self, expr): + # print("add_with_artificial_variable", expr) + self.artificial_counter = self.artificial_counter + 1 + av = SlackVariable(prefix='a', number=self.artificial_counter) + az = ObjectiveVariable('az') + az_row = expr.clone() + # print('Before add_rows') + # print(self) + self.add_row(az, az_row) + self.add_row(av, expr) + # print('after add_rows') + # print(self) + self.optimize(az) + az_tableau_row = self.rows[az] + # print("azTableauRow.constant =", az_tableau_row.constant) + if not approx_equal(az_tableau_row.constant, 0.0): + # print("azTableauRow.constant is 0") + self.remove_row(az) + self.remove_column(av) + raise RequiredFailure() + + e = self.rows.get(av) + if e is not None: + # print("av exists") + if e.is_constant: + # print("av is constant") + self.remove_row(av) + self.remove_row(az) + return + entry_var = e.any_pivotable_variable() + self.pivot(entry_var, av) + + # print("av shouldn't exist now") + assert av not in self.rows + self.remove_column(av) + self.remove_row(az) + + def try_adding_directly(self, expr): + # print("try_adding_directly", expr) + subject = self.choose_subject(expr) + if subject is None: + # print("try_adding_directly returning: False") + return False + + expr.new_subject(subject) + if subject in self.columns: + self.substitute_out(subject, expr) + + self.add_row(subject, expr) + # print("try_adding_directly returning: True") + return True + + def choose_subject(self, expr): + # print('choose_subject', expr) + subject = None + found_unrestricted = False + found_new_restricted = False + + retval_found = False + retval = None + for v, c in expr.terms.items(): # CHECK?? + if found_unrestricted: + if not v.is_restricted: + if v not in self.columns: + retval_found = True + retval = v + break + else: + if v.is_restricted: + if not found_new_restricted and not v.is_dummy and c < 0: + col = self.columns.get(v) + if col == None or (len(col) == 1 and self.objective in self.columns): + subject = v + found_new_restricted = True + else: + subject = v + found_unrestricted = True + + if retval_found: + return retval + + if subject: + return subject + + coeff = 0.0 + for v, c in expr.terms.items(): + if not v.is_dummy: + retval_found = True + retval = None + break + if not v in self.columns: + subject = v + coeff = c + + if retval_found: + return retval + + if not approx_equal(expr.constant, 0.0): + raise RequiredFailure() + + if coeff > 0: + expr = expr * -1 + + return subject + + def delta_edit_constant(self, delta, plus_error_var, minus_error_var): + expr_plus = self.rows.get(plus_error_var) + if expr_plus is not None: + expr_plus.constant = expr_plus.constant + delta + if expr_plus.constant < 0.0: + self.infeasible_rows.add(plus_error_var) + return + + expr_minus = self.rows.get(minus_error_var) + if expr_minus is not None: + expr_minus.constant = expr_minus.constant - delta + if expr_minus.constant < 0: + self.infeasible_rows.add(minus_error_var) + return + + try: + for basic_var in self.columns[minus_error_var]: + expr = self.rows[basic_var] + c = expr.coefficient_for(minus_error_var) + expr.constant = expr.constant + (c * delta) + if basic_var.is_restricted and expr.constant < 0: + self.infeasible_rows.add(basic_var) + except KeyError: + pass + + def dual_optimize(self): + z_row = self.rows.get(self.objective) + while self.infeasible_rows: + exit_var = self.infeasible_rows.pop() + entry_var = None + expr = self.rows.get(exit_var) + if expr: + if expr.constant < 0: + ratio = float('inf') + for v, cd in expr.terms.items(): + if cd > 0 and v.is_pivotable: + zc = z_row.coefficient_for(v) + r = zc / cd + if r < ratio: # JS difference? + entry_var = v + ratio = r + if ratio == float('inf'): + raise InternalError("ratio == nil (MAX_VALUE) in dual_optimize") + self.pivot(entry_var, exit_var) + + def optimize(self, z_var): + # print("optimize", z_var) + # print(self) + self.optimize_count = self.optimize_count + 1 + + z_row = self.rows[z_var] + entry_var = None + exit_var = None + + # print(self.objective) + # print(z_var) + # print(self.rows[self.objective]) + # print(self.rows[z_var]) + + while True: + objective_coeff = 0.0 + + # Not convinced the sort is correct here; but test suite + # doesn't pass reliably without it. + for v, c in sorted(z_row.terms.items(), key=lambda x: x[0].name): + # print('term check', v, v.is_pivotable, c) + if v.is_pivotable and c < objective_coeff: + # print('candidate found') + objective_coeff = c + entry_var = v + break; + + if objective_coeff >= -EPSILON or entry_var is None: + return + + # print('entry_var:', entry_var) + # print("objective_coeff:", objective_coeff) + + min_ratio = float('inf') + r = 0 + + for v in self.columns[entry_var]: + # print("checking", v) + if v.is_pivotable: + expr = self.rows[v] + coeff = expr.coefficient_for(entry_var) + # print('pivotable, coeff =', coeff) + if coeff < 0: + r = -expr.constant / coeff + if r < min_ratio: + min_ratio = r + exit_var = v + + if min_ratio == float('inf'): + raise RequiredFailure('Objective function is unbounded') + + self.pivot(entry_var, exit_var) + + # print(self) + + def pivot(self, entry_var, exit_var): + # print('pivot:',entry_var, exit_var) + if entry_var is None: + print("WARN - entry_var is None") + if exit_var is None: + print("WARN - exit_var is None") + + p_expr = self.remove_row(exit_var) + p_expr.change_subject(exit_var, entry_var) + self.substitute_out(entry_var, p_expr) + self.add_row(entry_var, p_expr) + + def reset_stay_constants(self): + # print("reset_stay_constants") + for p_var, m_var in self.stay_error_vars: + expr = self.rows.get(p_var) + if expr is None: + expr = self.rows.get(m_var) + if expr: + expr.constant = 0.0 + + def set_external_variables(self): + # print("set_external_variables") + # print(self) + for v in self.external_parametric_vars: + if self.rows.get(v): + # print("Variable %s in external_parametric_vars is basic" % v) + continue + v.value = 0.0 + + for v in self.external_rows: + expr = self.rows[v] + v.value = expr.constant + + self.needs_solving = False + + def insert_error_var(self, cn, var): + # print('insert_error_var', cn, var) + constraint_set = self.error_vars.get(var) + if not constraint_set: + constraint_set = set() + self.error_vars[cn] = constraint_set + + constraint_set.add(var) + + self.error_vars.setdefault(var, set()).add(var) diff --git a/lib/cassowary/simplex_solver.rb b/lib/cassowary/simplex_solver.rb new file mode 100644 index 00000000..26b9a1b9 --- /dev/null +++ b/lib/cassowary/simplex_solver.rb @@ -0,0 +1,769 @@ +module Cassowary + require_relative 'utils' + require_relative 'tableau' + require_relative 'variable' + require_relative 'expression' + require_relative 'constraint' + + class SolverEditContext + attr_accessor :solver + + def initialize(solver) + @solver = solver + end + + def enter + @solver.begin_edit + end + + def exit(type, value, tb) + @solver.end_edit + end + end # class SolverEditContext + + class SimplexSolver < Tableau + attr_accessor :stay_error_vars, :error_vars, :marker_vars, :objective, + :edit_var_map, :slack_counter, :artificial_counter, :dummy_counter, + :auto_solve, :need_solving, :optimize_count, :rows, :edit_variable_stack + + def initialize() + super() + @stay_error_vars = [] + @error_vars = {} + @marker_vars = {} + + @objective = ObjectiveVariable.new('Z') + @edit_var_map = {} + + @slack_counter = 0 + @artificial_counter = 0 + @dummy_counter = 0 + @auto_solve = true + @needs_solving = false + + @optimize_count = 0 + + @rows[self.objective] = Expression() + @edit_variable_stack = [0] + end +=begin + def __repr__(self): + parts = [] + parts.append('stay_error_vars: %s' % self.stay_error_vars) + parts.append('edit_var_map: %s' % self.edit_var_map) + return super(SimplexSolver, self).__repr__() + '\n' + '\n'.join(parts) + + + def add_constraint(self, cn, strength=None, weight=None): + if strength or weight: + cn = cn.clone() + if strength: + cn.strength = strength + if weight: + cn.weight = weight + + # print('add_constraint', cn) + expr, eplus, eminus, prev_edit_constant = self.new_expression(cn) + + if not self.try_adding_directly(expr): + self.add_with_artificial_variable(expr) + + self.needs_solving = True + + if cn.is_edit_constraint: + i = len(self.edit_var_map) + + self.edit_var_map[cn.variable] = EditInfo(cn, eplus, eminus, prev_edit_constant, i) + + if self.auto_solve: + self.optimize(self.objective) + self.set_external_variables() + + return cn +=end + def add_constraint(cn, strength=nil, weight=nil) + if strength || weight + cn = cn.clone + if strength + cn.strength = strength + end + if weight + cn.weight + end + end + expr, eplus, eminus, prev_edit_constant = self.new_expression(cn) # worry? + if ! self.try_adding_directly(expr) + self.add_with_artificial_variable(expr) + end + @needs_solving = true + if cn.is_edit_constraint + i = @edit_var_map.size + @edit_var_map[cn.variable] = EditInfo.new(cn, eplus, eminus, prev_edit_constant, i) + end + if @auto_solve + self.optimize(@objective) + self.set_external_variables() + end + end +=begin + def add_edit_var(self, v, strength=STRONG): + # print("add_edit_var", v, strength) + return self.add_constraint(EditConstraint(v, strength)) + + def remove_edit_var(self, v): + self.remove_constraint(self.edit_var_map[v].constraint) + + def edit(self): + return SolverEditContext(self) + + def resolve(self): + self.dual_optimize() + self.set_external_variables() + self.infeasible_rows.clear() + self.reset_stay_constants() +=end + def add_edit_var(v, strength=STRONG) + rreturn self.add_constraint(EditConstraint.new(v, strength)) + end + + def remove_edit_var(v) + self.remove_constraint(@edit_var_map[v].constraint) + end + + def edit() + return SolverEditContext.new(self) + end + + def resolve() + self.dual_optimize() + self.set_external_variables() + self.infeasible_rows.clear() + self.reset_stay_constants() + end + +=begin + + ####################################################################### + # Internals + ####################################################################### + + def new_expression(self, cn): + # print("* new_expression", cn) + # print("cn.is_inequality == ", cn.is_inequality) + # print("cn.is_required == ", cn.is_required) + expr = Expression(constant=cn.expression.constant) + eplus = None + eminus = None + prev_edit_constant = None + for v, c in cn.expression.terms.items(): + e = self.rows.get(v) + if not e: + expr.add_variable(v, c) + else: + expr.add_expression(e, c) + + if cn.is_inequality: + # print("Inequality, adding slack") + self.slack_counter = self.slack_counter + 1 + slack_var = SlackVariable(prefix='s', number=self.slack_counter) + expr.set_variable(slack_var, -1) + + self.marker_vars[cn] = slack_var + if not cn.is_required: + self.slack_counter = self.slack_counter + 1 + eminus = SlackVariable(prefix='em', number=self.slack_counter) + expr.set_variable(eminus, 1) + z_row = self.rows[self.objective] + z_row.set_variable(eminus, cn.strength * cn.weight) + self.insert_error_var(cn, eminus) + self.note_added_variable(eminus, self.objective) + else: + if cn.is_required: + # print("Equality, required") + self.dummy_counter = self.dummy_counter + 1 + dummy_var = DummyVariable(number=self.dummy_counter) + eplus = dummy_var + eminus = dummy_var + prev_edit_constant = cn.expression.constant + expr.set_variable(dummy_var, 1) + self.marker_vars[cn] = dummy_var + # print("Adding dummy_var == d%s" % self.dummy_counter) + else: + # print("Equality, not required") + self.slack_counter = self.slack_counter + 1 + eplus = SlackVariable(prefix='ep', number=self.slack_counter) + eminus = SlackVariable(prefix='em', number=self.slack_counter) + expr.set_variable(eplus, -1) + expr.set_variable(eminus, 1) + self.marker_vars[cn] = eplus + + z_row = self.rows[self.objective] + # print("z_row", z_row) + sw_coeff = cn.strength * cn.weight + # if sw_coeff == 0: + # print("cn ==", cn) + # print("adding ", eplus, "and", eminus, "with sw_coeff", sw_coeff) + z_row.set_variable(eplus, sw_coeff) + self.note_added_variable(eplus, self.objective) + z_row.set_variable(eminus, sw_coeff) + self.note_added_variable(eminus, self.objective) + + self.insert_error_var(cn, eminus) + self.insert_error_var(cn, eplus) + + if cn.is_stay_constraint: + self.stay_error_vars.append((eplus, eminus)) + elif cn.is_edit_constraint: + prev_edit_constant = cn.expression.constant + + # print('new_expression returning:', expr) + if expr.constant < 0: + expr.multiply(-1.0) + return expr, eplus, eminus, prev_edit_constant +=end + +=begin + def begin_edit(self): + assert len(self.edit_var_map) > 0 + self.infeasible_rows.clear() + self.reset_stay_constants() + self.edit_variable_stack.append(len(self.edit_var_map)) + + def end_edit(self): + assert len(self.edit_var_map) > 0 + self.resolve() + self.edit_variable_stack.pop() + self.remove_edit_vars_to(self.edit_variable_stack[-1]) + + def remove_all_edit_vars(self): + self.remove_edit_vars_to(0) +=end + + def begin_edit() + #assert len(self.edit_var_map) > 0 + self.infeasible_rows.clear() + self.reset_stay_constants() + self.edit_variable_stack.append(@edit_var_map.size) + end + + def end_edit() + #assert len(self.edit_var_map) > 0 + self.resolve() + self.edit_variable_stack.pop() + self.remove_edit_vars_to(@edit_variable_stack[-1]) # worry? + end + + def remove_all_edit_vars() + self.remove_edit_vars_to(0) + end + +=begin + def remove_edit_vars_to(self, n): + try: + removals = [] + for v, cei in self.edit_var_map.items(): + if cei.index >= n: + removals.append(v) + + for v in removals: + self.remove_edit_var(v) + + assert len(self.edit_var_map) == n + + except ConstraintNotFound: + raise InternalError('Constraint not found during internal removal') + + def add_stay(self, v, strength=WEAK, weight=1.0): + return self.add_constraint(StayConstraint(v, strength, weight)) +=end + def remove_edit_vars_to(n) + begin + + rescue ConstraintNotFound => err + raise InternalError('Constraint not found during internal removal') + end + end + + def add_stay(v, strength=WEAK, weight=1.0) + return self.add_constraint(StayConstraint.new(v, strength, weight)) + end +=begin + def remove_constraint(self, cn): + # print("removeConstraint", cn) + # print(self) + self.needs_solving = True + self.reset_stay_constants() + z_row = self.rows[self.objective] + + e_vars = self.error_vars.get(cn) + # print("e_vars ==", e_vars) + if e_vars: + for cv in e_vars: + try: + z_row.add_expression(self.rows[cv], -cn.weight * cn.strength, self.objective, self) + # print('add expression', self.rows[cv]) + except KeyError: + z_row.add_variable(cv, -cn.weight * cn.strength, self.objective, self) + # print('add variable', cv) + + try: + marker = self.marker_vars.pop(cn) + except KeyError: + raise ConstraintNotFound() + + # print("Looking to remove var", marker) + if not self.rows.get(marker): + col = self.columns[marker] + # print("Must pivot -- columns are", col) + exit_var = None + min_ratio = 0.0 + for v in col: + # print('check var', v) + if v.is_restricted: + # print('var', v, ' is restricted') + expr = self.rows[v] + coeff = expr.coefficient_for(marker) + # print("Marker", marker, "'s coefficient in", expr, "is", coeff) + if coeff < 0: + r = -expr.constant / coeff + if exit_var is None or r < min_ratio: # EXTRA BITS IN JS? + # print('set exit var = ',v,r) + min_ratio = r + exit_var = v + + if exit_var is None: + # print("exit_var is still None") + for v in col: + # print('check var', v) + if v.is_restricted: + # print('var', v, ' is restricted') + expr = self.rows[v] + coeff = expr.coefficient_for(marker) + # print("Marker", marker, "'s coefficient in", expr, "is", coeff) + r = expr.constant / coeff + if exit_var is None or r < min_ratio: + # print('set exit var = ',v,r) + min_ratio = r + exit_var = v + + if exit_var is None: + # print("exit_var is still None (again)") + if len(col) == 0: + # print('remove column',marker) + self.remove_column(marker) + else: + exit_var = [v for v in col if v != self.objective][-1] # ?? + # print('set exit var', exit_var) + + if exit_var is not None: + # print('Pivot', marker, exit_var,) + self.pivot(marker, exit_var) + + if self.rows.get(marker): + # print('remove row', marker) + expr = self.remove_row(marker) + + if e_vars: + # print('e_vars exist') + for v in e_vars: + if v != marker: + # print('remove column',v) + self.remove_column(v) + + if cn.is_stay_constraint: + if e_vars: + # for p_evar, m_evar in self.stay_error_vars: + remaining = [] + while self.stay_error_vars: + p_evar, m_evar = self.stay_error_vars.pop() + found = False + try: + # print('stay constraint - remove plus evar', p_evar) + e_vars.remove(p_evar) + found = True + except KeyError: + pass + try: + # print('stay constraint - remove minus evar', m_evar) + e_vars.remove(m_evar) + found = True + except KeyError: + pass + if not found: + remaining.append((p_evar, m_evar)) + self.stay_error_vars = remaining + + elif cn.is_edit_constraint: + assert e_vars is not None + # print('edit constraint - remove column', self.edit_var_map[cn.variable].edit_minus) + self.remove_column(self.edit_var_map[cn.variable].edit_minus) + del self.edit_var_map[cn.variable] + + if e_vars: + for e_var in e_vars: + # print('Remove error var', e_var) + del self.error_vars[e_var] + + if self.auto_solve: + # print('final auto solve') + self.optimize(self.objective) + self.set_external_variables() +=end +=begin + def resolve_array(self, new_edit_constants): + for v, cei in self.edit_var_map.items(): + self.suggest_value(v, new_edit_constants[cei.index]) + + self.resolve() + + def suggest_value(self, v, x): + cei = self.edit_var_map.get(v) + if not cei: + raise InternalError("suggestValue for variable %s, but var is not an edit variable" % v) + # print(cei) + delta = x - cei.prev_edit_constant + cei.prev_edit_constant = x + self.delta_edit_constant(delta, cei.edit_plus, cei.edit_minus) + + def solve(self): + if self.needs_solving: + self.optimize(self.objective) + self.set_external_variables() + + def set_edited_value(self, v, n): + if v not in self.columns or v not in self.rows: + v.value = n + + if not approx_equal(n, v.value): + self.add_edit_var(v) + self.begin_edit() + + self.suggest_value(v, n) + + self.end_edit() + + def add_var(self, v): + if v not in self.columns or v not in self.rows: + self.add_stay(v) +=end + def resolve_array(new_edit_constants) + @edit_var_map.each_pair do |v, cei| + self.suggest_value(v, new_edit_constants[cei.index]) # TODO class/type of cei is? + end + self.resolve + end + + def suggest_value(v, x) + cei = @edit_var_map[v] + if ! cei + raise InternalError("suggestValue for variable %s, but var is not an edit variable", v) #TODO + end + delta = x - cei.prev_edit_constant + cei.prev_edit_constant = x + self.delta_edit_constant(delta, cei.edit_plus, cei.edit_minus) + end + + def solve() + if @needs_solving + self.optimize(@objective) + self.set_external_variables() + end + end + + def set_edited_value(v, n) + #if v not in self.columns or v not in self.rows + if ! @columns[v] || !@rows[v] # TODO correct? + v.value = n + end + if ! approx_equal(n, v.value) + self.add_edit_var(v) + self.begin_edit() + self.suggest_value(v, n) + self.end_edit() + end + end + + def add_var(v) + # if v not in self.columns or v not in self.rows: + if ! @columns[v] || !@rows[v] # TODO correct? + self.add_stay(v) + end + end + +=begin + + def add_with_artificial_variable(self, expr): + # print("add_with_artificial_variable", expr) + self.artificial_counter = self.artificial_counter + 1 + av = SlackVariable(prefix='a', number=self.artificial_counter) + az = ObjectiveVariable('az') + az_row = expr.clone() + # print('Before add_rows') + # print(self) + self.add_row(az, az_row) + self.add_row(av, expr) + # print('after add_rows') + # print(self) + self.optimize(az) + az_tableau_row = self.rows[az] + # print("azTableauRow.constant =", az_tableau_row.constant) + if not approx_equal(az_tableau_row.constant, 0.0): + # print("azTableauRow.constant is 0") + self.remove_row(az) + self.remove_column(av) + raise RequiredFailure() + + e = self.rows.get(av) + if e is not None: + # print("av exists") + if e.is_constant: + # print("av is constant") + self.remove_row(av) + self.remove_row(az) + return + entry_var = e.any_pivotable_variable() + self.pivot(entry_var, av) + + # print("av shouldn't exist now") + assert av not in self.rows + self.remove_column(av) + self.remove_row(az) + +=end + def try_adding_directly(expr) + subject = self.choose_subject(expr) + if subject == nil + # print("try_adding_directly returning: False") + return false + end + expr.new_subject(subject) + if @columns[subject] + self.substitute_out(subject, expr) + end + self.add_row(subject, expr) + # print("try_adding_directly returning: True") + return True + end +=begin + def choose_subject(self, expr): + # print('choose_subject', expr) + subject = None + found_unrestricted = False + found_new_restricted = False + + retval_found = False + retval = None + for v, c in expr.terms.items(): # CHECK?? + if found_unrestricted: + if not v.is_restricted: + if v not in self.columns: + retval_found = True + retval = v + break + else: + if v.is_restricted: + if not found_new_restricted and not v.is_dummy and c < 0: + col = self.columns.get(v) + if col == None or (len(col) == 1 and self.objective in self.columns): + subject = v + found_new_restricted = True + else: + subject = v + found_unrestricted = True + + if retval_found: + return retval + + if subject: + return subject + + coeff = 0.0 + for v, c in expr.terms.items(): + if not v.is_dummy: + retval_found = True + retval = None + break + if not v in self.columns: + subject = v + coeff = c + + if retval_found: + return retval + + if not approx_equal(expr.constant, 0.0): + raise RequiredFailure() + + if coeff > 0: + expr = expr * -1 + + return subject +=end + def delta_edit_constant(delta, plus_error_var, minus_error_var) + expr_plus = @rows[plus_error_var] + if expr_plus + expr_plus.constant = expr_plus.constant + delta + if expr_plus.constant < 0.0 + @infeasible_rows.add(plus_error_var) + end + return + end + expr_minus = @rows.get[minus_error_var] + if expr_minus + expr_minus.constant = expr_minus.constant - delta + if expr_minus.constant < 0 + @infeasible_rows.add(minus_error_var) + end + return + end + begin + @columns[minus_error_var].each do |basic_var| + expr = @rows[basic_var] + c = expr.coefficient_for(minus_error_var) + expr.constant = expr.constant + (c * delta) + if basic_var.is_restricted && expr.constant < 0 + @infeasible_rows.add(basic_var) + end + end + rescue KeyError => e # Wrong! + pass + end + end + + def dual_optimize() + z_row = @rows[@objective] + @infeasible_rows.each do |t| # a Set + exit_var = @infeasible_rows.delete(t) + entry_var = None + expr = @rows[exit_var] + if expr + if expr.constant < 0 + ratio = float('inf') # TODO + expr.terms.each_pair do |v, cd| + if cd > 0 && v.is_pivotable + zc = z_row.coefficient_for(v) + r = zc / cd + if r < ratio # JS difference? + entry_var = v + ratio = r + end + end + end + if ratio == float('inf') # TODO + raise InternalError("ratio == nil (MAX_VALUE) in dual_optimize") + end + self.pivot(entry_var, exit_var) + end + end + end + end + +=begin + def optimize(self, z_var): + # print("optimize", z_var) + # print(self) + self.optimize_count = self.optimize_count + 1 + + z_row = self.rows[z_var] + entry_var = None + exit_var = None + + # print(self.objective) + # print(z_var) + # print(self.rows[self.objective]) + # print(self.rows[z_var]) + + while True: + objective_coeff = 0.0 + + # Not convinced the sort is correct here; but test suite + # doesn't pass reliably without it. + for v, c in sorted(z_row.terms.items(), key=lambda x: x[0].name): + # print('term check', v, v.is_pivotable, c) + if v.is_pivotable and c < objective_coeff: + # print('candidate found') + objective_coeff = c + entry_var = v + break; + + if objective_coeff >= -EPSILON or entry_var is None: + return + + # print('entry_var:', entry_var) + # print("objective_coeff:", objective_coeff) + + min_ratio = float('inf') + r = 0 + + for v in self.columns[entry_var]: + # print("checking", v) + if v.is_pivotable: + expr = self.rows[v] + coeff = expr.coefficient_for(entry_var) + # print('pivotable, coeff =', coeff) + if coeff < 0: + r = -expr.constant / coeff + if r < min_ratio: + min_ratio = r + exit_var = v + + if min_ratio == float('inf'): + raise RequiredFailure('Objective function is unbounded') + + self.pivot(entry_var, exit_var) + + # print(self) +=end + def pivot(entry_var, exit_var) + # print('pivot:',entry_var, exit_var) + if entry_var == nil + puts("WARN - entry_var is None") + end + if exit_var == nil + puts("WARN - exit_var is None") + end + p_expr = self.remove_row(exit_var) + p_expr.change_subject(exit_var, entry_var) + self.substitute_out(entry_var, p_expr) + self.add_row(entry_var, p_expr) + end + + def reset_stay_constants() + @stay_error_vars.each_pair do |p_var, m_var| + expr = @rows[p_var] + if !expr + expr = @rows[m_var] + end + if expr + expr.constant = 0.0 + end + end + end + + def set_external_variables() + @external_parametric_vars.each do |v| + if @rows[v] + continue + end + v.value = 0.0 + end + @external_rows.each do |v| + expr = @rows[v] + v.value = expr.constant + end + @needs_solving = false + end + + def insert_error_var(cn, var) + constraint_set = @error_vars[var] + if !constraint_set + constraint_set = Set.new() + @error_vars[cn] = constraint_set + end + constraint_set.add(var) + if !@error_vars[var] + @error_vars[var] = Set.new(var) # TODO: correct? + end + end + + end # class SimplexSolver +end #module Cassowary diff --git a/lib/cassowary/tableau.py b/lib/cassowary/tableau.py new file mode 100644 index 00000000..2b0e175e --- /dev/null +++ b/lib/cassowary/tableau.py @@ -0,0 +1,109 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + + +class Tableau(object): + def __init__(self): + # Map of variable to set of variables + self.columns = {} + + # Map of variable to LinearExpression + self.rows = {} + + # Set of Variables + self.infeasible_rows = set() + + # Set of Variables + self.external_rows = set() + + # Set of Variables. + self.external_parametric_vars = set() + + def __repr__(self): + parts = [] + parts.append('Tableau info:') + parts.append('Rows: %s (= %s constraints)' % (len(self.rows), len(self.rows) - 1)) + parts.append('Columns: %s' % len(self.columns)) + parts.append('Infeasible rows: %s' % len(self.infeasible_rows)) + parts.append('External basic variables: %s' % len(self.external_rows)) + parts.append('External parametric variables: %s' % len(self.external_parametric_vars)) + return '\n'.join(parts) + + def note_removed_variable(self, var, subject): + if subject: + self.columns[var].remove(subject) + + def note_added_variable(self, var, subject): + if subject: + self.columns.setdefault(var, set()).add(subject) + + def add_row(self, var, expr): + # print('add_row', var, expr) + self.rows[var] = expr + + for clv in expr.terms: + self.columns.setdefault(clv, set()).add(var) + if clv.is_external: + self.external_parametric_vars.add(clv) + + if var.is_external: + self.external_rows.add(var) + + # print(self) + + def remove_column(self, var): + rows = self.columns.pop(var, None) + + if rows: + for clv in rows: + expr = self.rows[clv] + expr.remove_variable(var) + + if var.is_external: + try: + self.external_rows.remove(var) + except KeyError: + pass + + try: + self.external_parametric_vars.remove(var) + except KeyError: + pass + + def remove_row(self, var): + # print("remove_row", var) + expr = self.rows.pop(var) + + for clv in expr.terms.keys(): + varset = self.columns[clv] + if varset: + # print("removing from varset", var) + varset.remove(var) + + try: + self.infeasible_rows.remove(var) + except KeyError: + pass + if var.is_external: + try: + self.external_rows.remove(var) + except KeyError: + pass + # print("remove_row returning", expr) + return expr + + def substitute_out(self, oldVar, expr): + varset = self.columns[oldVar] + for v in varset: + row = self.rows[v] + row.substitute_out(oldVar, expr, v, self) + if v.is_restricted and row.constant < 0.0: + self.infeasible_rows.add(v) + + if oldVar.is_external: + self.external_rows.add(oldVar) + try: + self.external_parametric_vars.remove(oldVar) + except KeyError: + pass + + del self.columns[oldVar] diff --git a/lib/cassowary/tableau.rb b/lib/cassowary/tableau.rb new file mode 100644 index 00000000..ada336b9 --- /dev/null +++ b/lib/cassowary/tableau.rb @@ -0,0 +1,177 @@ +module Cassowary + require 'set' + class Tableau + attr_accessor :colums,:rows,:infeasible_rows, :external_rows, :external_parametric_vars + + def initialize + # Map of variable to set of variables + @columns = {} + + # Map of variable to LinearExpression + @rows = {} + + # Set of Variables + @infeasible_rows = Set.new() + + # Set of Variables + @external_rows = Set.new() + + # Set of Variables. + @external_parametric_vars = Set.new() + end +=begin + def __repr__(self): + parts = [] + parts.append('Tableau info:') + parts.append('Rows: %s (= %s constraints)' % (len(self.rows), len(self.rows) - 1)) + parts.append('Columns: %s' % len(self.columns)) + parts.append('Infeasible rows: %s' % len(self.infeasible_rows)) + parts.append('External basic variables: %s' % len(self.external_rows)) + parts.append('External parametric variables: %s' % len(self.external_parametric_vars)) + return '\n'.join(parts) +=end + def note_removed_variable(var, subject) + if subject + @columns[var].delete(subject) + end + end + + # TODO: is my version of Dict.setdefault correct? + def note_added_variable(var, subject) + if subject + # python: self.columns.setdefault(var, set()).add(subject) + @columns[var] || @columns[var] = Set.new(subject) + end + end + +=begin + def add_row(self, var, expr): + # print('add_row', var, expr) + for clv in expr.terms: + self.columns.setdefault(clv, set()).add(var) + if clv.is_external: + self.external_parametric_vars.add(clv) + + if var.is_external: + self.external_rows.add(var) + + # print(self) +=end + + def add_row(var, expr) + @rows[var] = expr + expr.terms.each do |clv| + @columns[clv] || @columns[clv] = Set.new(var) + if clv.is_external + @external_parametric_vars.add(clv) + end + end + if var.is_external + @external_rows.add(var) + end + end +=begin + def remove_column(self, var): + rows = self.columns.pop(var, None) + + if rows: + for clv in rows: + expr = self.rows[clv] + expr.remove_variable(var) + + if var.is_external: + try: + self.external_rows.remove(var) + except KeyError: + pass + + try: + self.external_parametric_vars.remove(var) + except KeyError: + pass +=end + def remove_column(var) + rows = @columns.delete(var) + if rows + rows.each do |clv| + expr = @rows[clv] + expr.remove_variable(var) + end + end + if var.is_external + @external_rows.delete(var) + @external_parametric_vars.delete(var) + end + end +=begin + def remove_row(self, var): + # print("remove_row", var) + expr = self.rows.pop(var) + + for clv in expr.terms.keys(): + varset = self.columns[clv] + if varset: + # print("removing from varset", var) + varset.remove(var) + + try: + self.infeasible_rows.remove(var) + except KeyError: + pass + if var.is_external: + try: + self.external_rows.remove(var) + except KeyError: + pass + # print("remove_row returning", expr) + return expr +=end + def remove_row(var) + expr = @rows.delete(var) + expr.terms.each_key do |clv| + varset = @columns[clv] + if varset + varset.remove(var) + end + end + @infeasible_rows.delete(var) + if var.is_external + @external_rows.delete(var) + end + return expr + end +=begin + def substitute_out(self, oldVar, expr): + varset = self.columns[oldVar] + for v in varset: + row = self.rows[v] + row.substitute_out(oldVar, expr, v, self) + if v.is_restricted and row.constant < 0.0: + self.infeasible_rows.add(v) + + if oldVar.is_external: + self.external_rows.add(oldVar) + try: + self.external_parametric_vars.remove(oldVar) + except KeyError: + pass + + del self.columns[oldVar] +=end + def substitute_out(oldvar, expr) + varset = @columns[oldvar] + varset.each_value do |v| # TODO: correct? + rows = @rows[v] + row.substitute_out(oldvar, expr, v, self) + if v.is_restricted && row.constant < 0.0 + @infeasible_rows.add(oldvar) + end + if oldvar.is_external + @external_rows.add(oldvar) + @external_parametrics_vars.delete(oldvar) # TODO correct? + end + end + @columns.delete(oldvar) + end + end # tableau class +end # module Cassowary diff --git a/lib/cassowary/test.rb b/lib/cassowary/test.rb new file mode 100644 index 00000000..36c5ac34 --- /dev/null +++ b/lib/cassowary/test.rb @@ -0,0 +1,31 @@ +# really just tests if syntax is acceptable for loading. +# using it requires the other test pass. +require_relative 'utils' +require_relative 'tableau' +require_relative 'variable' +require_relative 'expression' +require_relative 'constraint' +include Cassowary + +# From utils.rb: +puts approx_equal 10, 10.00000001 # false +puts approx_equal 10, 10.0000000001 # true +puts approx_equal 10.1, 10.18, 0.01 # false +# From error.py +begin + raise RequiredFailure +rescue => e + puts e +end +# From edit_info.py +puts EditInfo.new(1,2,3,4,5).inspect +# From expression.py +v = AbstractVariable.new('foo') +puts "v is #{v.inspect}" +puts "mult #{v * 12.34}" +# From tableau.py +tb = Tableau.new +puts tb.inspect + +require_relative 'simplex_solver' + diff --git a/lib/cassowary/tests/__init__.py b/lib/cassowary/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/cassowary/tests/test_constraint.py b/lib/cassowary/tests/test_constraint.py new file mode 100644 index 00000000..f915771d --- /dev/null +++ b/lib/cassowary/tests/test_constraint.py @@ -0,0 +1,166 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +from cassowary import Variable, SimplexSolver, STRONG, WEAK + +# Internals +from cassowary.expression import Expression, Constraint +from cassowary.utils import approx_equal + + +class ConstraintTestCase(TestCase): + + def assertExpressionEqual(self, expr1, expr2): + + if not isinstance(expr1, Expression): + return False + + if not isinstance(expr2, Expression): + return False + + if not approx_equal(expr1.constant, expr2.constant): + return False + + if len(expr1.terms) != len(expr2.terms): + return False + + if any(expr2.terms.get(var) != value for var, value in expr1.terms.items()): + return False + + return True + + def test_from_constant(self): + "Constraint can be constructed from a constant expression" + ex = Expression(constant=10) + + # Constant value is ported to a float. + self.assertIsInstance(ex.constant, float) + self.assertAlmostEqual(ex.constant, 10.0) + + c1 = Constraint(ex) + + self.assertExpressionEqual(c1.expression, ex) + self.assertFalse(c1.is_inequality) + + c2 = Constraint(10) + + self.assertAlmostEqual(c2.expression.constant, 10) + self.assertTrue(c2.expression.is_constant) + self.assertFalse(c2.is_inequality) + + def test_variable_expression(self): + "Variable expressions can be constructed" + x = Variable('x', 167) + y = Variable('y', 2) + cly = Expression(y) + cly.add_expression(x) + + # def test_equation_from_variable_expression(self): + # "Constraints can be constructed from variables and expressions" + # x = Variable(name='x', value=167) + # cly = Expression(constant=2) + # eq = Constraint(x, Constraint.EQ, cly) + + def test_strength(self): + "Solvers should handle strengths correctly" + solver = SimplexSolver() + x = Variable(name='x', value=10) + y = Variable(name='y', value=20) + z = Variable(name='z', value=1) + w = Variable(name='w', value=1) + + # Default weights. + e0 = Constraint(x, Constraint.EQ, y) + solver.add_stay(y) + + solver.add_constraint(e0) + self.assertAlmostEqual(x.value, 20.0) + self.assertAlmostEqual(y.value, 20.0) + + # Add a weak constraint. + e1 = Constraint(x, Constraint.EQ, z, strength=WEAK) + solver.add_stay(x) + solver.add_constraint(e1) + self.assertAlmostEqual(x.value, 20.0) + self.assertAlmostEqual(z.value, 20.0) + + # Add a strong constraint. + e2 = Constraint(z, Constraint.EQ, w, strength=STRONG) + solver.add_stay(w) + solver.add_constraint(e2) + self.assertAlmostEqual(w.value, 1.0) + self.assertAlmostEqual(z.value, 1.0) + + def test_numbers_in_place_of_variables(self): + v = Variable(name='v', value=22) + eq = Constraint(v, Constraint.EQ, 5) + self.assertExpressionEqual(eq.expression, 5 - v) + + def test_equations_in_place_of_variables(self): + e = Expression(constant=10) + v = Variable(name='v', value=22) + eq = Constraint(e, Constraint.EQ, v) + + self.assertExpressionEqual(eq.expression, 10 - v) + + def test_works_with_nested_expressions(self): + e1 = Expression(constant=10) + e2 = Expression(Variable(name='z', value=10), 2, 4) + eq = Constraint(e1, Constraint.EQ, e2) + + self.assertExpressionEqual(eq.expression, e1 - e2) + + def test_inequality_expression_instantiation(self): + e = Expression(constant = 10) + ieq = Constraint(e) + self.assertExpressionEqual(ieq.expression, e) + + def test_operator_arguments_to_inequality(self): + v1 = Variable(name='v1', value=10) + v2 = Variable(name='v2', value=5) + ieq = Constraint(v1, Constraint.GEQ, v2) + self.assertExpressionEqual(ieq.expression, v1 - v2) + + ieq = Constraint(v1, Constraint.LEQ, v2) + self.assertExpressionEqual(ieq.expression, v2 - v1) + + def test_expression_with_variable_and_operators(self): + v = Variable(name='v', value=10) + ieq = Constraint(v, Constraint.GEQ, 5) + self.assertExpressionEqual(ieq.expression, v - 5) + + ieq = Constraint(v, Constraint.LEQ, 5) + self.assertExpressionEqual(ieq.expression, 5 - v) + + def test_expression_with_reused_variables(self): + e1 = Expression(constant=10) + e2 = Expression(Variable(name='c', value=10), 2, 4) + ieq = Constraint(e1, Constraint.GEQ, e2) + + self.assertExpressionEqual(ieq.expression, e1 - e2) + + ieq = Constraint(e1, Constraint.LEQ, e2) + self.assertExpressionEqual(ieq.expression, e2 - e1) + + def test_constructor_with_variable_operator_expression_args(self): + v = Variable(name='v', value=10) + e = Expression(Variable(name='x', value=5), 2, 4) + ieq = Constraint(v, Constraint.GEQ, e) + + self.assertExpressionEqual(ieq.expression, v - e) + + ieq = Constraint(v, Constraint.LEQ, e) + self.assertExpressionEqual(ieq.expression, e - v) + + def test_constructor_with_variable_operator_expression_args2(self): + v = Variable(name='v', value=10) + e = Expression(Variable(name='x', value=5), 2, 4) + ieq = Constraint(e, Constraint.GEQ, v) + self.assertExpressionEqual(ieq.expression, e - v) + + ieq = Constraint(e, Constraint.LEQ, v) + self.assertExpressionEqual(ieq.expression, v - e) diff --git a/lib/cassowary/tests/test_end_to_end.py b/lib/cassowary/tests/test_end_to_end.py new file mode 100644 index 00000000..d948843d --- /dev/null +++ b/lib/cassowary/tests/test_end_to_end.py @@ -0,0 +1,770 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +import random +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +from cassowary import RequiredFailure, SimplexSolver, STRONG, WEAK, MEDIUM, REQUIRED, Variable + +# Internals +from cassowary.expression import Constraint +from cassowary.utils import approx_equal + + +class EndToEndTestCase(TestCase): + def test_simple(self): + solver = SimplexSolver() + + x = Variable('x', 167) + y = Variable('y', 2) + eq = Constraint(x, Constraint.EQ, y) + + solver.add_constraint(eq) + self.assertAlmostEqual(x.value, y.value) + self.assertAlmostEqual(x.value, 0) + self.assertAlmostEqual(y.value, 0) + + def test_stay(self): + x = Variable('x', 5) + y = Variable('y', 10) + + solver = SimplexSolver() + solver.add_stay(x) + solver.add_stay(y) + + self.assertAlmostEqual(x.value, 5) + self.assertAlmostEqual(y.value, 10) + + def test_variable_geq_constant(self): + solver = SimplexSolver() + + x = Variable('x', 10) + ieq = Constraint(x, Constraint.GEQ, 100) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 100) + + def test_variable_leq_constant(self): + solver = SimplexSolver() + + x = Variable('x', 100) + ieq = Constraint(x, Constraint.LEQ, 10) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 10) + + def test_variable_equal_constant(self): + solver = SimplexSolver() + + x = Variable('x', 10) + eq = Constraint(100, Constraint.EQ, x) + solver.add_constraint(eq) + + self.assertAlmostEqual(x.value, 100) + + def test_constant_geq_variable(self): + # 10 >= x + solver = SimplexSolver() + + x = Variable('x', 100) + ieq = Constraint(10, Constraint.GEQ, x) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 10) + + def test_constant_leq_variable(self): + # 100 <= x + solver = SimplexSolver() + + x = Variable('x', 10) + ieq = Constraint(100, Constraint.LEQ, x) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 100) + + def test_geq_with_stay(self): + # stay width + # right >= 100 + solver = SimplexSolver() + + # x = 10 + x = Variable('x', 10) + # width = 10 + width = Variable('width', 10) + # right = x + width + right = x + width + # right >= 100 + ieq = Constraint(right, Constraint.GEQ, 100) + solver.add_stay(width) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 90) + self.assertAlmostEqual(width.value, 10) + + def test_leq_with_stay(self): + # stay width + # 100 <= right + solver = SimplexSolver() + + x = Variable('x', 10) + width = Variable('width', 10) + right = x + width + ieq = Constraint(100, Constraint.LEQ, right) + + solver.add_stay(width) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 90) + self.assertAlmostEqual(width.value, 10) + + def test_equality_with_stay(self): + # stay width, rightMin + # right >= rightMin + solver = SimplexSolver() + + x = Variable('x', 10) + width = Variable('width', 10) + rightMin = Variable('rightMin', 100) + + right = x + width + + eq = Constraint(right, Constraint.EQ, rightMin) + + solver.add_stay(width) + solver.add_stay(rightMin) + solver.add_constraint(eq) + + self.assertAlmostEqual(x.value, 90) + self.assertAlmostEqual(width.value, 10) + + def test_geq_with_variable(self): + # stay width, rightMin + # right >= rightMin + solver = SimplexSolver() + + x = Variable('x', 10) + width = Variable('width', 10) + rightMin = Variable('rightMin', 100) + + right = x + width + + ieq = Constraint(right, Constraint.GEQ, rightMin) + + solver.add_stay(width) + solver.add_stay(rightMin) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 90) + self.assertAlmostEqual(width.value, 10) + + def test_leq_with_variable(self): + # stay width + # right >= rightMin + solver = SimplexSolver() + + x = Variable('x', 10) + width = Variable('width', 10) + rightMin = Variable('rightMin', 100) + + right = x + width + + ieq = Constraint(rightMin, Constraint.LEQ, right) + + solver.add_stay(width) + solver.add_stay(rightMin) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x.value, 90) + self.assertAlmostEqual(width.value, 10) + + def test_equality_with_expression(self): + # stay width, rightMin + # right >= rightMin + solver = SimplexSolver() + + x1 = Variable('x1', 10) + width1 = Variable('width1', 10) + right1 = x1 + width1 + + x2 = Variable('x2', 100) + width2 = Variable('width2', 10) + right2 = x2 + width2 + + eq = Constraint(right1, Constraint.EQ, right2) + + solver.add_stay(width1) + solver.add_stay(width2) + solver.add_stay(x2) + solver.add_constraint(eq) + + self.assertAlmostEqual(x1.value, 100) + self.assertAlmostEqual(x2.value, 100) + self.assertAlmostEqual(width1.value, 10) + self.assertAlmostEqual(width2.value, 10) + + def test_geq_with_expression(self): + # stay width, rightMin + # right >= rightMin + solver = SimplexSolver() + + x1 = Variable('x1', 10) + width1 = Variable('width1', 10) + right1 = x1 + width1 + + x2 = Variable('x2', 100) + width2 = Variable('width2', 10) + right2 = x2 + width2 + + ieq = Constraint(right1, Constraint.GEQ, right2) + + solver.add_stay(width1) + solver.add_stay(width2) + solver.add_stay(x2) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x1.value, 100) + + def test_leq_with_expression(self): + # stay width, rightMin + # right >= rightMin + solver = SimplexSolver() + + x1 = Variable('x1', 10) + width1 = Variable('width1', 10) + right1 = x1 + width1 + + x2 = Variable('x2', 100) + width2 = Variable('width2', 10) + right2 = x2 + width2 + + ieq = Constraint(right2, Constraint.LEQ, right1) + + solver.add_stay(width1) + solver.add_stay(width2) + solver.add_stay(x2) + solver.add_constraint(ieq) + + self.assertAlmostEqual(x1.value, 100) + + def test_delete1(self): + solver = SimplexSolver() + x = Variable('x') + cbl = Constraint(x, Constraint.EQ, 100, WEAK) + solver.add_constraint(cbl) + + c10 = Constraint(x, Constraint.LEQ, 10) + c20 = Constraint(x, Constraint.LEQ, 20) + solver.add_constraint(c10) + solver.add_constraint(c20) + self.assertAlmostEqual(x.value, 10) + + solver.remove_constraint(c10) + self.assertAlmostEqual(x.value, 20) + + solver.remove_constraint(c20) + self.assertAlmostEqual(x.value, 100) + + c10again = Constraint(x, Constraint.LEQ, 10) + solver.add_constraint(c10) + solver.add_constraint(c10again) + self.assertAlmostEqual(x.value, 10) + + solver.remove_constraint(c10) + self.assertAlmostEqual(x.value, 10) + + solver.remove_constraint(c10again) + self.assertAlmostEqual(x.value, 100) + + def test_delete2(self): + solver = SimplexSolver() + x = Variable('x') + y = Variable('y') + + solver.add_constraint(Constraint(x, Constraint.EQ, 100, WEAK)) + solver.add_constraint(Constraint(y, Constraint.EQ, 120, STRONG)) + c10 = Constraint(x, Constraint.LEQ, 10) + c20 = Constraint(x, Constraint.LEQ, 20) + solver.add_constraint(c10) + solver.add_constraint(c20) + self.assertAlmostEqual(x.value, 10) + self.assertAlmostEqual(y.value, 120) + + solver.remove_constraint(c10) + self.assertAlmostEqual(x.value, 20) + self.assertAlmostEqual(y.value, 120) + + cxy = Constraint(x * 2, Constraint.EQ, y) + solver.add_constraint(cxy) + self.assertAlmostEqual(x.value, 20) + self.assertAlmostEqual(y.value, 40) + + solver.remove_constraint(c20) + self.assertAlmostEqual(x.value, 60) + self.assertAlmostEqual(y.value, 120) + + solver.remove_constraint(cxy) + self.assertAlmostEqual(x.value, 100) + self.assertAlmostEqual(y.value, 120) + + def test_casso1(self): + solver = SimplexSolver() + x = Variable('x') + y = Variable('y') + + solver.add_constraint(Constraint(x, Constraint.LEQ, y)) + solver.add_constraint(Constraint(y, Constraint.EQ, x + 3)) + solver.add_constraint(Constraint(x, Constraint.EQ, 10, WEAK)) + solver.add_constraint(Constraint(y, Constraint.EQ, 10, WEAK)) + + self.assertTrue( + (approx_equal(x.value, 10) and approx_equal(y.value, 13)) or + (approx_equal(x.value, 7) and approx_equal(y.value, 10)) + ) + + def test_inconsistent1(self): + solver = SimplexSolver() + x = Variable('x') + # x = 10 + solver.add_constraint(Constraint(x, Constraint.EQ, 10)) + # x = 5 + with self.assertRaises(RequiredFailure): + solver.add_constraint(Constraint(x, Constraint.EQ, 5)) + + def test_inconsistent2(self): + solver = SimplexSolver() + x = Variable('x') + solver.add_constraint(Constraint(x, Constraint.GEQ, 10)) + + with self.assertRaises(RequiredFailure): + solver.add_constraint(Constraint(x, Constraint.LEQ, 5)) + + def test_inconsistent3(self): + solver = SimplexSolver() + w = Variable('w') + x = Variable('x') + y = Variable('y') + z = Variable('z') + solver.add_constraint(Constraint(w, Constraint.GEQ, 10)) + solver.add_constraint(Constraint(x, Constraint.GEQ, w)) + solver.add_constraint(Constraint(y, Constraint.GEQ, x)) + solver.add_constraint(Constraint(z, Constraint.GEQ, y)) + solver.add_constraint(Constraint(z, Constraint.GEQ, 8)) + + with self.assertRaises(RequiredFailure): + solver.add_constraint(Constraint(z, Constraint.LEQ, 4)) + + def test_inconsistent4(self): + solver = SimplexSolver() + x = Variable('x') + y = Variable('y') + # x = 10 + solver.add_constraint(Constraint(x, Constraint.EQ, 10)) + # x = y + solver.add_constraint(Constraint(x, Constraint.EQ, y)) + # y = 5. Should fail. + with self.assertRaises(RequiredFailure): + solver.add_constraint(Constraint(y, Constraint.EQ, 5)) + + def test_multiedit1(self): + # This test stresses the edit session stack. begin_edit() starts a new + # "edit variable group" and "end_edit" closes it, leaving only the + # previously opened edit variables still active. + x = Variable('x') + y = Variable('y') + w = Variable('w') + h = Variable('h') + solver = SimplexSolver() + + # Add some stays + solver.add_stay(x) + solver.add_stay(y) + solver.add_stay(w) + solver.add_stay(h) + + # start an editing session + solver.add_edit_var(x) + solver.add_edit_var(y) + + with solver.edit(): + solver.suggest_value(x, 10) + solver.suggest_value(y, 20) + + # Force the system to resolve. + solver.resolve() + + self.assertAlmostEqual(x.value, 10) + self.assertAlmostEqual(y.value, 20) + self.assertAlmostEqual(w.value, 0) + self.assertAlmostEqual(h.value, 0) + + # Open a second set of variables for editing + solver.add_edit_var(w) + solver.add_edit_var(h) + + with solver.edit(): + solver.suggest_value(w, 30) + solver.suggest_value(h, 40) + + # Close the second set... + self.assertAlmostEqual(x.value, 10) + self.assertAlmostEqual(y.value, 20) + self.assertAlmostEqual(w.value, 30) + self.assertAlmostEqual(h.value, 40) + + # Now make sure the first set can still be edited + solver.suggest_value(x, 50) + solver.suggest_value(y, 60) + + self.assertAlmostEqual(x.value, 50) + self.assertAlmostEqual(y.value, 60) + self.assertAlmostEqual(w.value, 30) + self.assertAlmostEqual(h.value, 40) + + def test_multiedit2(self): + + x = Variable('x') + y = Variable('y') + w = Variable('w') + h = Variable('h') + + solver = SimplexSolver() + solver.add_stay(x) + solver.add_stay(y) + solver.add_stay(w) + solver.add_stay(h) + solver.add_edit_var(x) + solver.add_edit_var(y) + + solver.begin_edit() + solver.suggest_value(x, 10) + solver.suggest_value(y, 20) + solver.resolve() + solver.end_edit() + + self.assertAlmostEqual(x.value, 10) + self.assertAlmostEqual(y.value, 20) + self.assertAlmostEqual(w.value, 0) + self.assertAlmostEqual(h.value, 0) + + solver.add_edit_var(w) + solver.add_edit_var(h) + + solver.begin_edit() + solver.suggest_value(w, 30) + solver.suggest_value(h, 40) + solver.end_edit() + + self.assertAlmostEqual(x.value, 10) + self.assertAlmostEqual(y.value, 20) + self.assertAlmostEqual(w.value, 30) + self.assertAlmostEqual(h.value, 40) + + solver.add_edit_var(x) + solver.add_edit_var(y) + + solver.begin_edit() + solver.suggest_value(x, 50) + solver.suggest_value(y, 60) + solver.end_edit() + + self.assertAlmostEqual(x.value, 50) + self.assertAlmostEqual(y.value, 60) + self.assertAlmostEqual(w.value, 30) + self.assertAlmostEqual(h.value, 40) + + def test_multiedit3(self): + MIN = 100 + MAX = 500 + + width = Variable('width') + height = Variable('height') + top = Variable('top') + bottom = Variable('bottom') + left = Variable('left') + right = Variable('right') + + solver = SimplexSolver() + + iw = Variable('window_innerWidth', random.randrange(MIN, MAX)) + ih = Variable('window_innerHeight', random.randrange(MIN, MAX)) + + solver.add_constraint(Constraint(width, Constraint.EQ, iw, strength=STRONG, weight=0.0)) + solver.add_constraint(Constraint(height, Constraint.EQ, ih, strength=STRONG, weight=0.0)) + solver.add_constraint(Constraint(top, Constraint.EQ, 0, strength=WEAK, weight=0.0)) + solver.add_constraint(Constraint(left, Constraint.EQ, 0, strength=WEAK, weight=0.0)) + solver.add_constraint(Constraint(bottom, Constraint.EQ, top + height, strength=MEDIUM, weight=0.0)) + # Right is at least left + width + solver.add_constraint(Constraint(right, Constraint.EQ, left + width, strength=MEDIUM, weight=0.0)) + solver.add_stay(iw) + solver.add_stay(ih) + + # Propegate viewport size changes. + for i in range(0, 30): + + # Measurement should be cheap here. + iwv = random.randrange(MIN, MAX) + ihv = random.randrange(MIN, MAX) + + solver.add_edit_var(iw) + solver.add_edit_var(ih) + + with solver.edit(): + solver.suggest_value(iw, iwv) + solver.suggest_value(ih, ihv) + # solver.resolve() + + self.assertAlmostEqual(top.value, 0) + self.assertAlmostEqual(left.value, 0) + self.assertLessEqual(bottom.value, MAX) + self.assertGreaterEqual(bottom.value, MIN) + self.assertLessEqual(right.value, MAX) + self.assertGreaterEqual(right.value, MIN) + + def test_error_weights(self): + solver = SimplexSolver() + + x = Variable('x', 100) + y = Variable('y', 200) + z = Variable('z', 50) + + self.assertAlmostEqual(x.value, 100) + self.assertAlmostEqual(y.value, 200) + self.assertAlmostEqual(z.value, 50) + + solver.add_constraint(Constraint(z, Constraint.EQ, x, WEAK)) + solver.add_constraint(Constraint(x, Constraint.EQ, 20, WEAK)) + solver.add_constraint(Constraint(y, Constraint.EQ, 200, STRONG)) + + self.assertAlmostEqual(x.value, 20) + self.assertAlmostEqual(y.value, 200) + self.assertAlmostEqual(z.value, 20) + + solver.add_constraint(Constraint(z + 150, Constraint.LEQ, y, MEDIUM)) + + self.assertAlmostEqual(x.value, 20) + self.assertAlmostEqual(y.value, 200) + self.assertAlmostEqual(z.value, 20) + + def test_quadrilateral(self): + "A simple version of the quadrilateral test" + + solver = SimplexSolver() + + class Point(object): + def __init__(self, identifier, x, y): + self.x = Variable('x' + identifier, x) + self.y = Variable('y' + identifier, y) + + def __repr__(self): + return u'(%s, %s)' % (self.x.value, self.y.value) + + __hash__ = object.__hash__ + + def __eq__(self, other): + return self.x.value == other[0] and self.y.value == other[1] + + points = [ + Point('0', 10, 10), + Point('1', 10, 200), + Point('2', 200, 200), + Point('3', 200, 10), + + Point('m0', 0, 0), + Point('m1', 0, 0), + Point('m2', 0, 0), + Point('m3', 0, 0), + ] + midpoints = points[4:] + + # Add point stays + weight = 1.0 + multiplier = 2.0 + for point in points[:4]: + solver.add_stay(point.x, WEAK, weight) + solver.add_stay(point.y, WEAK, weight) + weight = weight * multiplier + + for start, end in [(0, 1), (1, 2), (2, 3), (3, 0)]: + cle = (points[start].x + points[end].x) / 2 + cleq = midpoints[start].x == cle + solver.add_constraint(cleq) + cle = (points[start].y + points[end].y) / 2 + cleq = midpoints[start].y == cle + solver.add_constraint(cleq) + + cle = points[0].x + 20 + solver.add_constraint(cle <= points[2].x) + solver.add_constraint(cle <= points[3].x) + + cle = points[1].x + 20 + solver.add_constraint(cle <= points[2].x) + solver.add_constraint(cle <= points[3].x) + + cle = points[0].y + 20 + solver.add_constraint(cle <= points[1].y) + solver.add_constraint(cle <= points[2].y) + + cle = points[3].y + 20 + solver.add_constraint(cle <= points[1].y) + solver.add_constraint(cle <= points[2].y) + + for point in points: + solver.add_constraint(point.x >= 0) + solver.add_constraint(point.y >= 0) + + solver.add_constraint(point.x <= 500) + solver.add_constraint(point.y <= 500) + + # Check the initial answers + + self.assertEqual(points[0], (10.0, 10.0)) + self.assertEqual(points[1], (10.0, 200.0)) + self.assertEqual(points[2], (200.0, 200.0)) + self.assertEqual(points[3], (200.0, 10.0)) + self.assertEqual(points[4], (10.0, 105.0)) + self.assertEqual(points[5], (105.0, 200.0)) + self.assertEqual(points[6], (200.0, 105.0)) + self.assertEqual(points[7], (105.0, 10.0)) + + # Now move point 2 to a new location + + solver.add_edit_var(points[2].x) + solver.add_edit_var(points[2].y) + + solver.begin_edit() + + solver.suggest_value(points[2].x, 300) + solver.suggest_value(points[2].y, 400) + + solver.end_edit() + + # Check that the other points have been moved. + self.assertEqual(points[0], (10.0, 10.0)) + self.assertEqual(points[1], (10.0, 200.0)) + self.assertEqual(points[2], (300.0, 400.0)) + self.assertEqual(points[3], (200.0, 10.0)) + self.assertEqual(points[4], (10.0, 105.0)) + self.assertEqual(points[5], (155.0, 300.0)) + self.assertEqual(points[6], (250.0, 205.0)) + self.assertEqual(points[7], (105.0, 10.0)) + + def test_buttons(self): + "A test of a horizontal layout of two buttons on a screen." + + class Button(object): + def __init__(self, identifier): + self.left = Variable('left' + identifier, 0) + self.width = Variable('width' + identifier, 0) + + def __repr__(self): + return u'(%s:%s)' % (self.left.value, self.width.value) + + solver = SimplexSolver() + + b1 = Button('b1') + b2 = Button('b2') + left_limit = Variable('left', 0) + right_limit = Variable('width', 0) + + left_limit.value = 0 + solver.add_stay(left_limit, REQUIRED) + stay = solver.add_stay(right_limit, WEAK) + + # The two buttons are the same width + solver.add_constraint(b1.width == b2.width) + + # b1 starts 50 from the left margin. + solver.add_constraint(b1.left == left_limit + 50) + + # b2 ends 50 from the right margin + solver.add_constraint(left_limit + right_limit == b2.left + b2.width + 50) + + # b2 starts at least 100 from the end of b1 + solver.add_constraint(b2.left >= (b1.left + b1.width + 100)) + + # b1 has a minimum width of 87 + solver.add_constraint(b1.width >= 87) + + # b1's preferred width is 87 + solver.add_constraint(b1.width == 87, STRONG) + + # b2's minimum width is 113 + solver.add_constraint(b2.width >= 113) + + # b2's preferred width is 113 + solver.add_constraint(b2.width == 113, STRONG) + + # Without imposign a stay on the right, right_limit will be the minimum width for the layout + self.assertAlmostEqual(b1.left.value, 50.0) + self.assertAlmostEqual(b1.width.value, 113.0) + self.assertAlmostEqual(b2.left.value, 263.0) + self.assertAlmostEqual(b2.width.value, 113.0) + self.assertAlmostEqual(right_limit.value, 426.0) + + # The window is 500 pixels wide. + right_limit.value = 500 + stay = solver.add_stay(right_limit, REQUIRED) + self.assertAlmostEqual(b1.left.value, 50.0) + self.assertAlmostEqual(b1.width.value, 113.0) + self.assertAlmostEqual(b2.left.value, 337.0) + self.assertAlmostEqual(b2.width.value, 113.0) + self.assertAlmostEqual(right_limit.value, 500.0) + solver.remove_constraint(stay) + + # Expand to 700 pixels + right_limit.value = 700 + stay = solver.add_stay(right_limit, REQUIRED) + self.assertAlmostEqual(b1.left.value, 50.0) + self.assertAlmostEqual(b1.width.value, 113.0) + self.assertAlmostEqual(b2.left.value, 537.0) + self.assertAlmostEqual(b2.width.value, 113.0) + self.assertAlmostEqual(right_limit.value, 700.0) + solver.remove_constraint(stay) + + # Contract to 600 + right_limit.value = 600 + stay = solver.add_stay(right_limit, REQUIRED) + self.assertAlmostEqual(b1.left.value, 50.0) + self.assertAlmostEqual(b1.width.value, 113.0) + self.assertAlmostEqual(b2.left.value, 437.0) + self.assertAlmostEqual(b2.width.value, 113.0) + self.assertAlmostEqual(right_limit.value, 600.0) + solver.remove_constraint(stay) + + def test_paper_example(self): + + solver = SimplexSolver() + + left = Variable('left') + middle = Variable('middle') + right = Variable('right') + + solver.add_constraint(middle == (left + right) / 2) + solver.add_constraint(right == left + 10) + solver.add_constraint(right <= 100) + solver.add_constraint(left >= 0) + + # Check that all the required constraints are true: + self.assertAlmostEqual((left.value + right.value) / 2, middle.value) + self.assertAlmostEqual(right.value, left.value + 10) + self.assertGreaterEqual(left.value, 0) + self.assertLessEqual(right.value, 100) + + # Set the middle value to a stay + middle.value = 45.0 + solver.add_stay(middle) + + # Check that all the required constraints are true: + self.assertAlmostEqual((left.value + right.value) / 2, middle.value) + self.assertAlmostEqual(right.value, left.value + 10) + self.assertGreaterEqual(left.value, 0) + self.assertLessEqual(right.value, 100) + + # But more than that - since we gave a position for middle, we know + # where all the points should be. + + self.assertAlmostEqual(left.value, 40) + self.assertAlmostEqual(middle.value, 45) + self.assertAlmostEqual(right.value, 50) diff --git a/lib/cassowary/tests/test_expression.py b/lib/cassowary/tests/test_expression.py new file mode 100644 index 00000000..238e4187 --- /dev/null +++ b/lib/cassowary/tests/test_expression.py @@ -0,0 +1,283 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +from cassowary import InternalError, Variable + +# Internals +from cassowary.expression import Expression, SlackVariable + + +class ExpressionTestCase(TestCase): + def assertExpressionEqual(self, expr, output): + self.assertEqual(repr(expr), output) + + def test_empty_expression(self): + expr = Expression() + self.assertExpressionEqual(expr, '0.0') + self.assertAlmostEqual(expr.constant, 0.0) + self.assertEqual(len(expr.terms), 0) + + def test_full_expression(self): + x = Variable('x', 167) + + expr = Expression(x, 2, 3) + self.assertExpressionEqual(expr, '3.0 + 2.0*x[167.0]') + self.assertAlmostEqual(expr.constant, 3.0) + self.assertEqual(len(expr.terms), 1) + self.assertAlmostEqual(expr.terms.get(x), 2.0) + + def test_variable_expression(self): + x = Variable('x', 167) + + expr = Expression(x) + self.assertExpressionEqual(expr, 'x[167.0]') + self.assertAlmostEqual(expr.constant, 0.0) + self.assertEqual(len(expr.terms), 1) + self.assertAlmostEqual(expr.terms.get(x), 1.0) + + expr = Expression(x, 3) + self.assertExpressionEqual(expr, '3.0*x[167.0]') + self.assertAlmostEqual(expr.constant, 0.0) + self.assertEqual(len(expr.terms), 1) + self.assertAlmostEqual(expr.terms.get(x), 3.0) + + def test_constant_expression(self): + expr = Expression(constant=4) + self.assertExpressionEqual(expr, '4.0') + self.assertAlmostEqual(expr.constant, 4.0) + self.assertEqual(len(expr.terms), 0) + + def test_add(self): + x = Variable('x', 167) + y = Variable('y', 42) + + # Add a constant to an expression + self.assertExpressionEqual(Expression(x) + 2, '2.0 + x[167.0]') + self.assertExpressionEqual(3 + Expression(x), '3.0 + x[167.0]') + + # Add a variable to an expression + self.assertExpressionEqual(y + Expression(x), 'x[167.0] + y[42.0]') + self.assertExpressionEqual(Expression(x) + y, 'x[167.0] + y[42.0]') + + # Add an expression to an expression + self.assertExpressionEqual(Expression(x) + Expression(y), 'x[167.0] + y[42.0]') + self.assertExpressionEqual(Expression(x, 20, 2) + Expression(y, 10, 5), '7.0 + 20.0*x[167.0] + 10.0*y[42.0]') + + def test_sub(self): + x = Variable('x', 167) + y = Variable('y', 42) + + # Subtract a constant from an expression + self.assertExpressionEqual(Expression(x) - 2, '-2.0 + x[167.0]') + self.assertExpressionEqual(3 - Expression(x), '3.0 + -1.0*x[167.0]') + + # Subtract a variable from an expression + self.assertExpressionEqual(y - Expression(x), '-1.0*x[167.0] + y[42.0]') + self.assertExpressionEqual(Expression(x) - y, 'x[167.0] + -1.0*y[42.0]') + + # Subtract an expression from an expression + self.assertExpressionEqual(Expression(x) - Expression(y), 'x[167.0] + -1.0*y[42.0]') + self.assertExpressionEqual(Expression(x, 20, 2) - Expression(y, 10, 5), '-3.0 + 20.0*x[167.0] + -10.0*y[42.0]') + + def test_mul(self): + x = Variable('x', 167) + y = Variable('y', 42) + + # Multiply an expression by a constant + self.assertExpressionEqual(Expression(x) * 2, '2.0*x[167.0]') + self.assertExpressionEqual(3 * Expression(x), '3.0*x[167.0]') + + # Can't multiply an expression by a variable unless the expression is a constant + with self.assertRaises(TypeError): + y * Expression(x) + with self.assertRaises(TypeError): + Expression(x) * y + self.assertExpressionEqual(x * Expression(constant=2), '2.0*x[167.0]') + self.assertExpressionEqual(Expression(constant=3) * x, '3.0*x[167.0]') + + # Can't multiply an expression by an expression unless + # one of the expressions is a constant. + with self.assertRaises(TypeError): + Expression(x) * Expression(y) + with self.assertRaises(TypeError): + Expression(x, 20, 2) * Expression(y, 10, 5) + self.assertExpressionEqual(Expression(x, 20, 2) * Expression(constant=5), '10.0 + 100.0*x[167.0]') + self.assertExpressionEqual(Expression(x, 20) * Expression(constant=5), '100.0*x[167.0]') + self.assertExpressionEqual(Expression(constant=2) * Expression(y, 10, 5), '10.0 + 20.0*y[42.0]') + self.assertExpressionEqual(Expression(constant=2) * Expression(y, 10), '20.0*y[42.0]') + + def test_complex_math(self): + x = Variable('x', 167) + y = Variable('y', 2) + ex = 4 + x * 3 + 2 * y + self.assertExpressionEqual(ex, '4.0 + 3.0*x[167.0] + 2.0*y[2.0]') + + def test_clone(self): + v = Variable('v', 10) + expr = Expression(v, 20, 2) + clone = expr.clone() + + self.assertEqual(clone.constant, expr.constant) + self.assertEqual(len(clone.terms), len(expr.terms)) + self.assertEqual(clone.terms.get(v), 20) + + def test_is_constant(self): + e1 = Expression() + e2 = Expression(constant=10) + e3 = Expression(Variable('o', 10), 20) + e4 = Expression(Variable('o', 10), 20, 2) + + self.assertTrue(e1.is_constant) + self.assertTrue(e2.is_constant) + self.assertFalse(e3.is_constant) + self.assertFalse(e4.is_constant) + + def test_multiply(self): + v = Variable('v', 10) + expr = Expression(v, 20, 2) + expr.multiply(-1) + + self.assertExpressionEqual(expr, '-2.0 + -20.0*v[10.0]') + + def test_add_variable(self): + o = Variable('o', 10) + a = Expression(o, 20, 2) + v = Variable('v', 20) + + self.assertEqual(len(a.terms), 1) + self.assertAlmostEqual(a.terms.get(o), 20.0) + + # implicit coefficient of 1 + a.add_variable(v) + self.assertEqual(len(a.terms), 2) + self.assertAlmostEqual(a.terms.get(v), 1.0) + + # add again, with different coefficient + a.add_variable(v, 2.0) + self.assertEqual(len(a.terms), 2) + self.assertAlmostEqual(a.terms.get(v), 3.0) + + # add again, with resulting 0 coefficient. should remove the term. + a.add_variable(v, -3) + self.assertEqual(len(a.terms), 1) + self.assertIsNone(a.terms.get(v)) + + # try adding the removed term back, with 0 coefficient + a.add_variable(v, 0) + self.assertEqual(len(a.terms), 1) + self.assertIsNone(a.terms.get(v)) + + def test_add_expression_variable(self): + a = Expression(Variable('o', 10), 20, 2) + v = Variable('v', 20) + + # should work just like add_variable + a.add_expression(v, 2) + self.assertEqual(len(a.terms), 2) + self.assertEqual(a.terms.get(v), 2) + + def test_add_expression(self): + va = Variable('a', 10) + vb = Variable('b', 20) + vc = Variable('c', 5) + a = Expression(va, 20, 2) + + # different variable and implicit coefficient of 1, should make new term + a.add_expression(Expression(vb, 10, 5)) + self.assertEqual(len(a.terms), 2) + self.assertEqual(a.constant, 7) + self.assertEqual(a.terms.get(vb), 10) + + # same variable, should reuse existing term + a.add_expression(Expression(vb, 2, 5)) + self.assertEqual(len(a.terms), 2) + self.assertEqual(a.constant, 12) + self.assertEqual(a.terms.get(vb), 12) + + # another variable and a coefficient, + # should multiply the constant and all terms in the new expression + a.add_expression(Expression(vc, 1, 2), 2) + self.assertEqual(len(a.terms), 3) + self.assertEqual(a.constant, 16) + self.assertEqual(a.terms.get(vc), 2) + + def test_coefficient_for(self): + va = Variable('a', 10) + vb = Variable('b', 20) + a = Expression(va, 20, 2) + + self.assertEqual(a.coefficient_for(va), 20) + self.assertEqual(a.coefficient_for(vb), 0) + + def test_set_variable(self): + va = Variable('a', 10) + vb = Variable('b', 20) + a = Expression(va, 20, 2) + + # set existing variable + a.set_variable(va, 2) + self.assertEqual(len(a.terms), 1) + self.assertEqual(a.coefficient_for(va), 2) + + # set new variable + a.set_variable(vb, 2) + self.assertEqual(len(a.terms), 2) + self.assertEqual(a.coefficient_for(vb), 2) + + def test_any_pivotable_variable(self): + # t.e(c.InternalError, Expression(10), 'any_pivotable_variable') + e = Expression(constant=10) + with self.assertRaises(InternalError): + e.any_pivotable_variable() + # t.e(c.InternalError, Expression(10), 'any_pivotable_variable') + + va = Variable('a', 10) + vb = SlackVariable('slack', 1) + a = Expression(va, 20, 2) + + self.assertIsNone(a.any_pivotable_variable()) + + a.set_variable(vb, 2) + self.assertEqual(vb, a.any_pivotable_variable()) + + def test_substitute_out(self): + v1 = Variable('1', 20) + v2 = Variable('2', 2) + a = Expression(v1, 2, 2) # 2*v1 + 2 + + # new variable + a.substitute_out(v1, Expression(v2, 4, 4)) + self.assertEqual(a.constant, 10) + self.assertIsNone(a.terms.get(v1)) + self.assertEqual(a.terms.get(v2), 8) + + # existing variable + a.set_variable(v1, 1) + a.substitute_out(v2, Expression(v1, 2, 2)) + + self.assertEqual(a.constant, 26) + self.assertIsNone(a.terms.get(v2)) + self.assertEqual(a.terms.get(v1), 17) + + def test_new_subject(self): + v = Variable('v', 10) + e = Expression(v, 2, 5) + + self.assertEqual(e.new_subject(v), 0.5) + self.assertEqual(e.constant, -2.5) + self.assertIsNone(e.terms.get(v)) + self.assertTrue(e.is_constant) + + def test_change_subject(self): + va = Variable('a', 10) + vb = Variable('b', 5) + e = Expression(va, 2, 5) + + e.change_subject(vb, va) + self.assertEqual(e.constant, -2.5) + self.assertIsNone(e.terms.get(va)) + self.assertEqual(e.terms.get(vb), 0.5) diff --git a/lib/cassowary/tests/test_simplex_solver.py b/lib/cassowary/tests/test_simplex_solver.py new file mode 100644 index 00000000..71f9cf45 --- /dev/null +++ b/lib/cassowary/tests/test_simplex_solver.py @@ -0,0 +1,69 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +from cassowary import Variable, SimplexSolver, STRONG, REQUIRED + +# internals +from cassowary.expression import Constraint + + +class SimplexSolverTestCase(TestCase): + def test_constructor(self): + "A solver can be constructed" + solver = SimplexSolver() + + self.assertEqual(len(solver.columns), 0) + self.assertEqual(len(solver.rows), 1) + self.assertEqual(len(solver.infeasible_rows), 0) + self.assertEqual(len(solver.external_rows), 0) + self.assertEqual(len(solver.external_parametric_vars), 0) + + def test_add_edit_var_required(self): + "Solver works with REQUIRED strength" + solver = SimplexSolver() + + a = Variable(name='a') + + solver.add_stay(a, STRONG, 0) + solver.resolve() + + self.assertEqual(a.value, 0) + + solver.add_edit_var(a, REQUIRED) + solver.begin_edit() + solver.suggest_value(a, 2) + solver.resolve() + + self.assertEqual(a.value, 2) + + def test_add_edit_var_required_after_suggestions(self): + "Solver works with REQUIRED strength after many suggestions" + solver = SimplexSolver() + a = Variable(name='a') + b = Variable(name='b') + + solver.add_stay(a, STRONG, 0) + solver.add_constraint(Constraint(a, Constraint.EQ, b, REQUIRED)) + solver.resolve() + + self.assertEqual(b.value, 0) + self.assertEqual(a.value, 0) + + solver.add_edit_var(a, REQUIRED) + solver.begin_edit() + solver.suggest_value(a, 2) + solver.resolve() + + self.assertEqual(a.value, 2) + self.assertEqual(b.value, 2) + + solver.suggest_value(a, 10) + solver.resolve() + + self.assertEqual(a.value, 10) + self.assertEqual(b.value, 10) + diff --git a/lib/cassowary/tests/test_tableau.py b/lib/cassowary/tests/test_tableau.py new file mode 100644 index 00000000..e28e52a5 --- /dev/null +++ b/lib/cassowary/tests/test_tableau.py @@ -0,0 +1,22 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +# Internals +from cassowary.tableau import Tableau + + +class TableauTestCase(TestCase): + + def test_tableau(self): + "A Tableau can be constructed" + tableau = Tableau() + + self.assertEqual(len(tableau.columns), 0) + self.assertEqual(len(tableau.rows), 0) + self.assertEqual(len(tableau.infeasible_rows), 0) + self.assertEqual(len(tableau.external_rows), 0) + self.assertEqual(len(tableau.external_parametric_vars), 0) diff --git a/lib/cassowary/tests/test_variable.py b/lib/cassowary/tests/test_variable.py new file mode 100644 index 00000000..c3c3e424 --- /dev/null +++ b/lib/cassowary/tests/test_variable.py @@ -0,0 +1,133 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + +from unittest import TestCase +if not hasattr(TestCase, 'assertIsNotNone'): + # For Python2.6 compatibility + from unittest2 import TestCase + +from cassowary import Variable + +# Internals +from cassowary.expression import DummyVariable, SlackVariable, ObjectiveVariable + + +class VariableTestCase(TestCase): + def assertExpressionEqual(self, expr, output): + self.assertEqual(repr(expr), output) + + def test_Variable(self): + "A Variable can be constructed." + var = Variable('foo') + + self.assertEqual(var.name, 'foo') + self.assertAlmostEqual(var.value, 0.0) + self.assertFalse(var.is_dummy) + self.assertTrue(var.is_external) + self.assertFalse(var.is_pivotable) + self.assertFalse(var.is_restricted) + + self.assertEqual(repr(var), 'foo[0.0]') + + def test_Variable_with_value(self): + "A Variable can be constructed with a value." + var = Variable('foo', 42.0) + + self.assertEqual(var.name, 'foo') + self.assertAlmostEqual(var.value, 42.0) + self.assertFalse(var.is_dummy) + self.assertTrue(var.is_external) + self.assertFalse(var.is_pivotable) + self.assertFalse(var.is_restricted) + + self.assertEqual(repr(var), 'foo[42.0]') + + def test_DummyVariable(self): + "A Dummy Variable can be constructed." + var = DummyVariable(3) + + self.assertEqual(var.name, 'd3') + self.assertTrue(var.is_dummy) + self.assertFalse(var.is_external) + self.assertFalse(var.is_pivotable) + self.assertTrue(var.is_restricted) + + self.assertEqual(repr(var), 'd3:dummy') + + def test_SlackVariable(self): + "A Slack Variable can be constructed." + var = SlackVariable('foo', 3) + + self.assertEqual(var.name, 'foo3') + self.assertFalse(var.is_dummy) + self.assertFalse(var.is_external) + self.assertTrue(var.is_pivotable) + self.assertTrue(var.is_restricted) + + self.assertEqual(repr(var), 'foo3:slack') + + def test_ObjectiveVariable(self): + "An Objective Variable can be constructed." + var = ObjectiveVariable('foo') + + self.assertEqual(var.name, 'foo') + self.assertFalse(var.is_dummy) + self.assertFalse(var.is_external) + self.assertFalse(var.is_pivotable) + self.assertFalse(var.is_restricted) + + self.assertEqual(repr(var), 'foo:obj') + + def test_add(self): + x = Variable('x', 167) + + # Add a constant to an expression + self.assertExpressionEqual(x + 2, '2.0 + x[167.0]') + self.assertExpressionEqual(3 + x, '3.0 + x[167.0]') + + # Any other type fails + with self.assertRaises(TypeError): + x + object() + with self.assertRaises(TypeError): + object() + x + + def test_sub(self): + x = Variable('x', 167) + + # Subtract a constant from an expression + self.assertExpressionEqual(x - 2, '-2.0 + x[167.0]') + self.assertExpressionEqual(3 - x, '3.0 + -1.0*x[167.0]') + + # Any other type fails + with self.assertRaises(TypeError): + x - object() + with self.assertRaises(TypeError): + object() - x + + def test_mul(self): + x = Variable('x', 167) + + # Multiply an expression by a constant + self.assertExpressionEqual(x * 2, '2.0*x[167.0]') + self.assertExpressionEqual(3 * x, '3.0*x[167.0]') + + # Any other type fails + with self.assertRaises(TypeError): + x * object() + with self.assertRaises(TypeError): + object() * x + + def test_div(self): + x = Variable('x', 167) + + # Multiply an expression by a constant + self.assertExpressionEqual(x / 2, '0.5*x[167.0]') + + # No reverse division, however + with self.assertRaises(TypeError): + 2 / x + + # Any other type fails + with self.assertRaises(TypeError): + x / object() + with self.assertRaises(TypeError): + object() / x diff --git a/lib/cassowary/utils.py b/lib/cassowary/utils.py new file mode 100644 index 00000000..ac348305 --- /dev/null +++ b/lib/cassowary/utils.py @@ -0,0 +1,30 @@ +from __future__ import print_function, unicode_literals, absolute_import, division + + +EPSILON = 1e-8 + +REQUIRED = 1001001000 +STRONG = 1000000 +MEDIUM = 1000 +WEAK = 1 + + +def approx_equal(a, b, epsilon=EPSILON): + "A comparison mechanism for floats" + return abs(a - b) < epsilon + + +def repr_strength(strength): + """Convert a numerical strength constant into a human-readable value. + + We could wrap this up in an enum, but enums aren't available in Py2; + we could use a utility class, but we really don't need the extra + implementation weight. In practice, this repr is only used for debug + purposes during development. + """ + return { + REQUIRED: 'Required', + STRONG: 'Strong', + MEDIUM: 'Medium', + WEAK: 'Weak' + }[strength] diff --git a/lib/cassowary/utils.rb b/lib/cassowary/utils.rb new file mode 100644 index 00000000..60720a65 --- /dev/null +++ b/lib/cassowary/utils.rb @@ -0,0 +1,38 @@ + +module Cassowary + + # from utils.py + EPSILON = 1e-8 + + REQUIRED = 1001001000 + STRONG = 1000000 + MEDIUM = 1000 + WEAK = 1 + + def approx_equal(a, b, epsilon=EPSILON) + return (a - b).abs < epsilon + end + + # from error.py + class CassowaryException < StandardError; end + class InternalError < CassowaryException; end + class ConstraintNotFound < CassowaryException; end + class RequiredFailure < CassowaryException; end + class NotImplemented < CassowaryException; end + class ZeroDivisionError < CassowaryException; end + + #from edit_info.py + class EditInfo + attr_accessor :constraint, :edit_plus, :edit_minus, :prev_edit_constant, :index + + def initialize(constraint, edit_plus, edit_minus, prev_edit_constant, index) + @constraint = constraint + @edit_plus = edit_plus + @edit_minus = edit_minus + @prev_edit_constant = prev_edit_constant + @index = index + end + end + + +end diff --git a/lib/cassowary/variable.rb b/lib/cassowary/variable.rb new file mode 100644 index 00000000..24142259 --- /dev/null +++ b/lib/cassowary/variable.rb @@ -0,0 +1,163 @@ +########################################################################### +# Variables +# +# Variables are the atomic unit of linear programming, describing the +# quantities that are to be solved and constrained. +########################################################################### +module Cassowary + + class AbstractVariable + attr_accessor :name, :is_dummy, :is_external, :is_restricted + def initialize(name) + @name = name + @is_dummy = false + @is_external = false + @is_pivotable = false + @is_restricted = false + end + + ## TODO verify + # Var * x + def *(x) + if x.kind_of? Numeric + return Expression.new(self, x) + end + if x.kind_of? Expression + if x.is_constant + return Expression.new(self, x.constant) + else + raise NotImplmented + end + end + end + + # Var / x + def /(x) + if x.kind_of? Numeric + if approx_equal(x, 0) + raise ZeroDivisionError + end + return Expression.new(value = 1.0 / x.constant) + elsif x.kind_of? Expression + if x.is_constant + return Expression(value = 1.0 /x.contant) + else + return NotImplemented + end + else + return NotImplemented + end + end + + # Var + x + def +(x) + if x.kind_of? Numeric + return Expression.new(constant = x) + elsif x.kind_of? Expression + return Expression.new(self) + x + elsif x.kind_of? AbstractVariable + return Expression.new(self) + Expression.new(x) + else + return NotImplemented + end + end + + # Var - x + def -(x) + if x.kind_of? Numeric + return Expression.new(constant=-x) + elsif x.kind_of? Expression + return Expression.new(self) - x + elsif x.kind_of? AbstractVariable + return Expression.new(self) - Expression.new(x) + else + return NotImplemented + end + end + + end # class AbstractVariable + + class Variable < AbstractVariable + attr_accessor :value, :is_external + + def initialize(name, value=0.0) + super(name) + @value = value.to_f + @is_external = True + end + # __hash__ = object.__hash__ + + # Var == x # eq, NOT assignment! NOT standard? + def ==(other) + case other.kind_of? + when Expression, Variable, Numeric + return Constrait.new(self, Constraint.EQ, other) + else + return NotImplemented + end + end + + # Var < x + def <(other) + # < and <= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self <= other + end + + # Var <= x + def <=(other) + case other.kind_of? + when Expression, Variable, Numeric + return Constraint.new(self, Constraint.LEQ, other) + else + return NotImplemented + end + end + + # Var > x + def >(other) + # > and >= are equivalent in the API; it's effectively true + # due to float arithmetic, and it makes the API a little less hostile, + # because all the comparison operators exist. + return self >= other + end + + def >= (other) + case other.kind_of? + when Expression, Variable, Numeric + return Constraint.new(self, Constraint.GEQ, other) + else + return NotImplemented + end + end + + end # class Variable + + + class DummyVariable < AbstractVariable + attr_accessor :is_dummy, :is_restricted + + def initialize(number) + super(name = sprintf("d%s",number)) + @is_dummy = true + @is_restricted = true + end + end # class DummyVariable + + # what does this do other than a sugar layer + class ObjectiveVariable < AbstractVariable + def initialize(name) + super(name) + end + end # class ObjectiveVariable + + class SlackVariable < AbstractVariable + attr_accessor :is_pivotable, :is_restricted + def initialize(prefix, number) + super(name=sprintf("%s%s", prefix, number)) + @is_pivotable = true + @is_restricted = true + end + end # class SlackVariable +end # module diff --git a/shoes/canvas.c b/shoes/canvas.c index 5bb7cfe5..736f544d 100644 --- a/shoes/canvas.c +++ b/shoes/canvas.c @@ -211,9 +211,10 @@ void shoes_undo_transformation(cairo_t *cr, shoes_transform *st, shoes_place *pl VALUE shoes_add_ele(shoes_canvas *canvas, VALUE ele) { if (NIL_P(ele)) return ele; - if (canvas->insertion <= -1) + //if (canvas->insertion <= -1) { TODO TODO this should have worked !!! + if (canvas->insertion < 0) { rb_ary_push(canvas->contents, ele); - else { + } else { rb_ary_insert_at(canvas->contents, canvas->insertion, 0, ele); canvas->insertion++; } @@ -282,7 +283,7 @@ static void shoes_canvas_empty(shoes_canvas *canvas, int extras) { shoes_ele_remove_all(canvas->contents); if (extras) shoes_extras_remove_all(canvas); if (! NIL_P(canvas->layout_mgr)) { - shoes_layout_clear(canvas); + shoes_layout_cleared(canvas); } canvas->stage = stage; @@ -732,12 +733,6 @@ void shoes_canvas_insert(VALUE self, long i, VALUE ele, VALUE block) { if (rb_respond_to(block, s_widget)) rb_funcall(block, s_widget, 1, self); else { - if (! NIL_P(canvas->layout_mgr)) { - shoes_layout *lay; - Data_Get_Struct(canvas->layout_mgr, shoes_layout, lay); - fprintf(stderr, "Insert into Layout\n"); - - } else shoes_canvas_memdraw(self, block); } canvas->insertion = -2; diff --git a/shoes/canvas.h b/shoes/canvas.h index d6c0e910..78891009 100644 --- a/shoes/canvas.h +++ b/shoes/canvas.h @@ -318,6 +318,7 @@ VALUE shoes_canvas_reset(VALUE); VALUE shoes_canvas_contents(VALUE); VALUE shoes_canvas_children(VALUE); void shoes_canvas_size(VALUE, int, int); +void shoes_canvas_insert(VALUE self, long i, VALUE ele, VALUE block); VALUE shoes_canvas_clear_contents(int, VALUE *, VALUE); VALUE shoes_canvas_remove(VALUE); VALUE shoes_canvas_refresh_slot(VALUE); // 3.3.0 diff --git a/shoes/types/layout.c b/shoes/types/layout.c index 82723c61..ad61fd2a 100644 --- a/shoes/types/layout.c +++ b/shoes/types/layout.c @@ -26,15 +26,11 @@ FUNC_M("+layout", layout, -1); void shoes_layout_init() { cLayout = rb_define_class_under(cTypes, "Layout", cShoes); - //cLayout = rb_define_class_under(cFlow, "Layout", cShoes); - rb_define_method(cLayout, "append", CASTHOOK(shoes_layout_append), -1); // crash in shoes_canvas_memdraw - //rb_define_method(cLayout, "append", CASTHOOK(shoes_canvas_append), -1); // slot is being modified - rb_define_method(cLayout, "clear", CASTHOOK(shoes_canvas_clear_contents), -1); - rb_define_method(cLayout, "prepend", CASTHOOK(shoes_canvas_prepend), -1); - rb_define_method(cLayout, "before", CASTHOOK(shoes_canvas_before), -1); - rb_define_method(cLayout, "after", CASTHOOK(shoes_canvas_after), -1); + rb_define_method(cLayout, "clear", CASTHOOK(shoes_layout_clear), -1); + rb_define_method(cLayout, "insert", CASTHOOK(shoes_layout_insert), -1); + rb_define_method(cLayout, "delete_at", CASTHOOK(shoes_layout_delete_at), -1); rb_define_method(cLayout, "rule", CASTHOOK(shoes_layout_add_rule), -1); - rb_define_method(cLayout, "finish", CASTHOOK(shoes_layout_compute), -1); + rb_define_method(cLayout, "finish", CASTHOOK(shoes_layout_finish), -1); /* RUBY_M generates defines (allow Ruby to call the FUNC_M funtions rb_define_method(cCanvas, "layout", CASTHOOK(shoes_canvas_c_layout), -1); @@ -90,35 +86,55 @@ VALUE shoes_layout_new(VALUE attr, VALUE parent) { return obj; } -VALUE shoes_layout_append(int argc, VALUE *argv, VALUE self) { +VALUE shoes_layout_insert(int argc, VALUE *argv, VALUE self) { shoes_layout *lay; Data_Get_Struct(self, shoes_layout, lay); VALUE canvas = lay->canvas; rb_arg_list args; - rb_parse_args(argc, argv, "o,&", &args); - shoes_canvas_insert(canvas, -1, Qnil, args.a[0]); + rb_parse_args(argc, argv, "i&", &args); + long pos = NUM2LONG(args.a[0]); + shoes_canvas_insert(canvas, pos, Qnil, args.a[1]); return self; } -VALUE shoes_layout_prepend(VALUE self, VALUE ele) { - shoes_layout *ly; - Data_Get_Struct(self, shoes_layout, ly); - shoes_canvas *canvas; - Data_Get_Struct(ly->canvas, shoes_canvas, canvas); - shoes_layout_add_ele(canvas, ele); +VALUE shoes_layout_delete_at(int argc, VALUE *argv, VALUE self) { + shoes_layout *lay; + Data_Get_Struct(self, shoes_layout, lay); + VALUE canvas_obj = lay->canvas; + shoes_canvas *canvas; + Data_Get_Struct(canvas_obj, shoes_canvas, canvas); + rb_arg_list args; + rb_parse_args(argc, argv, "i", &args); + long pos = NUM2LONG(args.a[0]); + VALUE ele = rb_ary_entry(canvas->contents, pos); + ID s_rm = rb_intern("remove"); + if (rb_respond_to(ele, s_rm)) + rb_funcall(ele, s_rm, 0, Qnil); + return self; +} + +VALUE shoes_layout_clear(int argc, VALUE *argv, VALUE self) { + shoes_layout *lay; + Data_Get_Struct(self, shoes_layout, lay); + VALUE canvas = lay->canvas; + rb_arg_list args; + shoes_canvas_clear_contents(argc, argv, canvas); } // called from shoes_add_ele (def in canvas.c) by widget creators // The ele has already been added to canvas->contents void shoes_layout_add_ele(shoes_canvas *canvas, VALUE ele) { if (rb_obj_is_kind_of(ele, cBackground)) { - //fprintf(stderr, "skipping background widget\n"); + fprintf(stderr, "skipping background widget\n"); return; } // Find a delegate or use the internal default? if (canvas->layout_mgr != Qnil) { shoes_layout *ly; Data_Get_Struct(canvas->layout_mgr, shoes_layout, ly); + // for debug + shoes_canvas *cvs; + Data_Get_Struct(ly->canvas, shoes_canvas, cvs); if (! NIL_P(ly->delegate)) { //printf(stderr,"Delegating\n"); VALUE del = ly->delegate; @@ -137,7 +153,7 @@ void shoes_layout_add_ele(shoes_canvas *canvas, VALUE ele) { } // called from inside shoes (shoes_canvas_clear) -void shoes_layout_clear(shoes_canvas *canvas) { +void shoes_layout_cleared(shoes_canvas *canvas) { fprintf(stderr,"shoes_layout_clear called\n"); if (canvas->layout_mgr != Qnil) { shoes_layout *ly; @@ -171,10 +187,22 @@ void shoes_layout_default_clear(shoes_canvas *canvas) { fprintf(stderr, "default layout clear\n"); } -VALUE shoes_layout_add_rule(VALUE self, VALUE rule) { +VALUE shoes_layout_add_rule(int argc, VALUE *argv, VALUE self) { + shoes_layout *lay; + Data_Get_Struct(self, shoes_layout, lay); + VALUE cobj = lay->canvas; + shoes_canvas *canvas; + Data_Get_Struct(cobj, shoes_canvas, canvas); + fprintf(stderr,"shoes_layout_add_rule called\n"); } -VALUE shoes_layout_compute(VALUE self) { +VALUE shoes_layout_finish(int argc, VALUE *argv, VALUE self) { + shoes_layout *lay; + Data_Get_Struct(self, shoes_layout, lay); + VALUE cobj = lay->canvas; + shoes_canvas *canvas; + Data_Get_Struct(cobj, shoes_canvas, canvas); + fprintf(stderr, "shoes_layout_compute called\n"); } diff --git a/shoes/types/layout.h b/shoes/types/layout.h index d3d7132b..9e6cfa1f 100644 --- a/shoes/types/layout.h +++ b/shoes/types/layout.h @@ -3,9 +3,9 @@ typedef struct { VALUE delegate; - VALUE canvas; // TODO is this used? + VALUE canvas; // fields below belong to the C crafted default layout manager, what ever that - // is. + // turns out to be. int x; int y; } shoes_layout; @@ -14,16 +14,19 @@ extern VALUE cLayout; void shoes_layout_init(); VALUE shoes_layout_new(VALUE attr, VALUE parent); // slot like methods: -VALUE shoes_layout_append(int argc, VALUE *argv, VALUE self); -VALUE shoes_layout_prepend(VALUE self, VALUE ele); -VALUE shoes_layout_add_rule(VALUE self, VALUE rule); -VALUE shoes_layout_compute(VALUE self); +VALUE shoes_layout_insert(int argc, VALUE *argv, VALUE self); +VALUE shoes_layout_delete_at(int argc, VALUE *argv, VALUE self); +VALUE shoes_layout_clear(int argc, VALUE *argv, VALUE self); +VALUE shoes_layout_refresh(int argc, VALUE *argv, VALUE self); +VALUE shoes_layout_add_rule(int argc, VALUE *argv, VALUE self); +VALUE shoes_layout_finish(int argc, VALUE *argv, VALUE self); +// canvas calls these, delegate to usr or the secret layout +void shoes_layout_cleared(shoes_canvas *canvas); void shoes_layout_add_ele(shoes_canvas *canvas, VALUE ele); -void shoes_layout_clear(shoes_canvas *canvas); VALUE shoes_layout_delete_ele(shoes_canvas *canvas, VALUE ele); -// methods for the default manager. TODO: Write it. +// TODO: delagate methods for the secret default manager. void shoes_layout_default_add(shoes_canvas *canvas, VALUE ele); void shoes_layout_default_clear(shoes_canvas *canvas); #endif