Skip to content

Commit

Permalink
Merge pull request MLton#586 from MatthewFluet/useless-bugfix
Browse files Browse the repository at this point in the history
Treat length of sequences as a "slot" in Useless optimization
  • Loading branch information
MatthewFluet authored Dec 20, 2024
2 parents bf18753 + a4b433f commit 665812b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 49 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ Here are the changes from version 20210117 to YYYYMMDD.

=== Details

* 2024-12-20
** Fix bug in the optimization of representations of sequences in
`Useless` SSA optimization. Thanks to Humza Shahid (hummy123) for
the bug report.

* 2024-12-14
** Update SML/NJ libraries to SML/NJ 110.99.6.1.

Expand Down
109 changes: 60 additions & 49 deletions mlton/ssa/useless.fun
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ structure Value =
value: value} Set.t
and value =
Array of {elt: slot,
length: t,
length: slot,
useful: Useful.t}
| Ground of Useful.t
| Ref of {arg: slot,
useful: Useful.t}
| Tuple of slot vector
| Vector of {elt: slot,
length: t}
length: slot}
| Weak of {arg: slot,
useful: Useful.t}
withtype slot = t * Exists.t
Expand All @@ -135,7 +135,7 @@ structure Value =
Array {elt, length, useful} =>
seq [str "array ",
record [("useful", Useful.layout useful),
("length", layout length),
("length", layoutSlot length),
("elt", layoutSlot elt)]]
| Ground g => seq [str "ground ", Useful.layout g]
| Ref {arg, useful, ...} =>
Expand All @@ -145,7 +145,7 @@ structure Value =
| Tuple ss => Vector.layout layoutSlot ss
| Vector {elt, length} =>
seq [str "vector ",
record [("length", layout length),
record [("length", layoutSlot length),
("elt", layoutSlot elt)]]
| Weak {arg, useful} =>
seq [str "weak ",
Expand All @@ -167,7 +167,9 @@ structure Value =
case (v, v') of
(Array {useful = u, length = n, elt = e},
Array {useful = u', length = n', elt = e'}) =>
(Useful.== (u, u'); unify (n, n'); unifySlot (e, e'))
(Useful.== (u, u')
; unifySlot (n, n')
; unifySlot (e, e'))
| (Ground g, Ground g') => Useful.== (g, g')
| (Ref {useful = u, arg = a},
Ref {useful = u', arg = a'}) =>
Expand All @@ -176,7 +178,8 @@ structure Value =
Vector.foreach2 (ss, ss', unifySlot)
| (Vector {length = n, elt = e},
Vector {length = n', elt = e'}) =>
(unify (n, n'); unifySlot (e, e'))
(unifySlot (n, n')
; unifySlot (e, e'))
| (Weak {useful = u, arg = a}, Weak {useful = u', arg = a'}) =>
(Useful.== (u, u'); unifySlot (a, a'))
| _ => Error.bug "Useless.Value.unify: strange"
Expand All @@ -202,7 +205,7 @@ structure Value =
coerceSlot {from = s, to = s'})
| (Vector {length = n, elt = e},
Vector {length = n', elt = e'}) =>
(coerce {from = n, to = n'}
(coerceSlot {from = n, to = n'}
; coerceSlot {from = e, to = e'})
| (Weak _, Weak _) => unify (from, to)
| _ => Error.bug "Useless.Value.coerce: strange"
Expand All @@ -219,6 +222,15 @@ structure Value =
Unit.layout)
coerce

val coerceSlot =
Trace.trace ("Useless.Value.coerceSlot",
fn {from, to} => let open Layout
in record [("from", layoutSlot from),
("to", layoutSlot to)]
end,
Unit.layout)
coerceSlot

fun coerces {from, to} =
Vector.foreach2 (from, to, fn (from, to) =>
coerce {from = from, to = to})
Expand All @@ -228,11 +240,11 @@ structure Value =
fun loop (v: t): unit =
case value v of
Array {length, elt, useful} =>
(f useful; loop length; slot elt)
(f useful; slot length; slot elt)
| Ground u => f u
| Tuple ss => Vector.foreach (ss, slot)
| Ref {arg, useful} => (f useful; slot arg)
| Vector {length, elt} => (loop length; slot elt)
| Vector {length, elt} => (slot length; slot elt)
| Weak {arg, useful} => (f useful; slot arg)
and slot (v, _) = loop v
in
Expand All @@ -258,7 +270,7 @@ structure Value =
| Ground u => SOME u
| Ref {useful = u, ...} => SOME u
| Tuple ss => Vector.peekMap (ss, someUseful o #1)
| Vector {length, ...} => SOME (deground length)
| Vector {length, ...} => SOME (deground (#1 length))
| Weak {useful = u, ...} => SOME u

fun allOrNothing (v: t): Useful.t option =
Expand All @@ -281,18 +293,17 @@ structure Value =
let val e = Exists.new ()
in (loop (t, e :: es), e)
end
val loop = fn t => loop (t, es)
val value =
case Type.dest t of
Type.Array t =>
Array {useful = useful (),
length = loop (Type.word (WordSize.seqIndex ())),
length = slot (Type.word (WordSize.seqIndex ())),
elt = slot t}
| Type.Ref t => Ref {arg = slot t,
useful = useful ()}
| Type.Tuple ts => Tuple (Vector.map (ts, slot))
| Type.Vector t =>
Vector {length = loop (Type.word (WordSize.seqIndex ())),
Vector {length = slot (Type.word (WordSize.seqIndex ())),
elt = slot t}
| Type.Weak t => Weak {arg = slot t,
useful = useful ()}
Expand Down Expand Up @@ -379,7 +390,7 @@ structure Value =
struct
datatype t =
Array of (Type.t * bool)
| Length
| Length of bool
| LengthRef
| UnitRef
| Unit
Expand All @@ -388,7 +399,7 @@ structure Value =
struct
datatype t =
Vector of (Type.t * bool)
| Length
| Length of bool
| Unit
end

Expand All @@ -399,28 +410,28 @@ structure Value =
in (if Exists.doesExist e then SOME t else NONE, b)
end
and arrayRep {elt, length, useful}: ArrayRep.t =
(case (getNewSlot elt, isUseful length, Useful.isUseful useful) of
((SOME ty, eltUseful), lengthUseful, useful) =>
(case (getNewSlot elt, getNewSlot length, Useful.isUseful useful) of
((SOME ty, eltUseful), (_, lengthUseful), useful) =>
ArrayRep.Array (ty, eltUseful orelse lengthUseful orelse useful)
| ((NONE, false), true, false) => ArrayRep.Length
| ((NONE, false), true, true) => ArrayRep.LengthRef
| ((NONE, false), false, true) => ArrayRep.UnitRef
| ((NONE, false), false, false) => ArrayRep.Unit
| ((NONE, false), (SOME _, _), true) => ArrayRep.LengthRef
| ((NONE, false), (SOME _, lengthUseful), false) => ArrayRep.Length lengthUseful
| ((NONE, false), (NONE, false), true) => ArrayRep.UnitRef
| ((NONE, false), (NONE, false), false) => ArrayRep.Unit
| _ => Error.bug (concat
["Value.arrayRep: ",
"elt: ", Layout.toString (layoutSlot elt), "; ",
"length: ", Layout.toString (layout length), "; ",
"length: ", Layout.toString (layoutSlot length), "; ",
"useful: ", Layout.toString (Useful.layout useful)]))
and vectorRep {elt, length}: VectorRep.t =
(case (getNewSlot elt, isUseful length) of
((SOME ty, eltUseful), lengthUseful) =>
(case (getNewSlot elt, getNewSlot length) of
((SOME ty, eltUseful), (_, lengthUseful)) =>
VectorRep.Vector (ty, eltUseful orelse lengthUseful)
| ((NONE, false), true) => VectorRep.Length
| ((NONE, false), false) => VectorRep.Unit
| ((NONE, false), (SOME _, lengthUseful)) => VectorRep.Length lengthUseful
| ((NONE, false), (NONE, false)) => VectorRep.Unit
| _ => Error.bug (concat
["Value.vectorRep: ",
"elt: ", Layout.toString (layoutSlot elt), "; ",
"length: ", Layout.toString (layout length)]))
"length: ", Layout.toString (layoutSlot length)]))
and getNew (T s): Type.t * bool =
let
val {value, ty, new, ...} = Set.! s
Expand All @@ -439,7 +450,7 @@ structure Value =
Array arg =>
(case arrayRep arg of
ArrayRep.Array (ty, u) => (Type.array ty, u)
| ArrayRep.Length => (Type.word (WordSize.seqIndex ()), true)
| ArrayRep.Length u => (Type.word (WordSize.seqIndex ()), u)
| ArrayRep.LengthRef => (Type.reff (Type.word (WordSize.seqIndex ())), true)
| ArrayRep.UnitRef => (Type.reff Type.unit, true)
| ArrayRep.Unit => (Type.unit, false))
Expand All @@ -461,7 +472,7 @@ structure Value =
| Vector arg =>
(case vectorRep arg of
VectorRep.Vector (ty, u) => (Type.vector ty, u)
| VectorRep.Length => (Type.word (WordSize.seqIndex ()), true)
| VectorRep.Length u => (Type.word (WordSize.seqIndex ()), u)
| VectorRep.Unit => (Type.unit, false))
| Weak {arg, useful} =>
orU (wrap (arg, Type.weak), useful)
Expand Down Expand Up @@ -550,7 +561,7 @@ fun transform (program: Program.t): Program.t =
case value v of
Array {useful = u, length = n, elt = e} =>
(Useful.makeUseful u
; deepMakeUseful n
; slot n
; slot e)
| Ground u =>
(Useful.makeUseful u
Expand All @@ -574,7 +585,7 @@ fun transform (program: Program.t): Program.t =
| _ => ()))
| Ref {useful = u, arg = a} => (Useful.makeUseful u; slot a)
| Tuple vs => Vector.foreach (vs, slot)
| Vector {length = n, elt = e} => (deepMakeUseful n; slot e)
| Vector {length = n, elt = e} => (slot n; slot e)
| Weak {useful = u, arg = a} => (Useful.makeUseful u; slot a)
end
val deepMakeUseful =
Expand Down Expand Up @@ -614,7 +625,7 @@ fun transform (program: Program.t): Program.t =
| _ => ()))
| Ref {useful = u, ...} => Useful.makeUseful u
| Tuple vs => Vector.foreach (vs, slot)
| Vector {length = n, elt = e} => (shallowMakeUseful n; slot e)
| Vector {length = n, elt = e} => (slot n; slot e)
| Weak {useful = u, ...} => Useful.makeUseful u
end
val shallowMakeUseful =
Expand All @@ -637,7 +648,7 @@ fun transform (program: Program.t): Program.t =
| Ground u => Useful.makeWanted u
| Ref {useful = u, ...} => Useful.makeWanted u
| Tuple vs => Vector.foreach (vs, slot)
| Vector {length = n, elt = e} => (makeWanted n; slot e)
| Vector {length = n, elt = e} => (slot n; slot e)
| Weak {useful = u, ...} => Useful.makeWanted u
end
val makeWanted =
Expand Down Expand Up @@ -681,7 +692,7 @@ fun transform (program: Program.t): Program.t =
val (l, e) = arrayLengthAndElt (arg 0)
val (l', e') = seqLengthAndElt result
in
coerce {from = l, to = l'}
coerceSlot {from = l, to = l'}
; coerceSlot {from = e, to = e'}
end
fun update () =
Expand All @@ -699,7 +710,7 @@ fun transform (program: Program.t): Program.t =
| Prim.Array_array => seq arrayElt
| Prim.Array_copyArray => copy arrayEltSlot
| Prim.Array_copyVector => copy vectorEltSlot
| Prim.Array_length => length arrayLength
| Prim.Array_length => length (#1 o arrayLength)
| Prim.Array_sub => sub arrayElt
| Prim.Array_toArray => toSeq arrayLengthAndElt
| Prim.Array_toVector => toSeq vectorLengthAndElt
Expand Down Expand Up @@ -740,7 +751,7 @@ fun transform (program: Program.t): Program.t =
| Prim.Ref_assign => coerce {from = arg 1, to = deref (arg 0)}
| Prim.Ref_deref => return (deref (arg 0))
| Prim.Ref_ref => coerce {from = arg 0, to = deref result}
| Prim.Vector_length => length vectorLength
| Prim.Vector_length => length (#1 o vectorLength)
| Prim.Vector_sub => sub vectorElt
| Prim.Vector_vector => seq vectorElt
| Prim.Weak_canGet =>
Expand Down Expand Up @@ -976,7 +987,7 @@ fun transform (program: Program.t): Program.t =
Prim.Array_alloc _ =>
(case Value.arrayRep (Value.arrayArg resultValue) of
Value.ArrayRep.Array _ => doit ()
| Value.ArrayRep.Length => simple (Var (arg 0))
| Value.ArrayRep.Length _ => simple (Var (arg 0))
| Value.ArrayRep.LengthRef =>
simple (PrimApp {prim = Prim.Ref_ref,
targs = Vector.new1 (Type.word (WordSize.seqIndex ())),
Expand All @@ -989,7 +1000,7 @@ fun transform (program: Program.t): Program.t =
| Prim.Array_array =>
(case Value.arrayRep (Value.arrayArg resultValue) of
Value.ArrayRep.Array (eltTy, _) => makeSeq eltTy
| Value.ArrayRep.Length =>
| Value.ArrayRep.Length _ =>
let
val len_var = Var.newNoname ()
val len_ty = Type.word (WordSize.seqIndex ())
Expand Down Expand Up @@ -1029,7 +1040,7 @@ fun transform (program: Program.t): Program.t =
| Prim.Array_length =>
(case Value.arrayRep (Value.arrayArg (value (arg 0))) of
Value.ArrayRep.Array _ => doit ()
| Value.ArrayRep.Length => simple (Var (arg 0))
| Value.ArrayRep.Length _ => simple (Var (arg 0))
| Value.ArrayRep.LengthRef =>
simple (PrimApp {prim = Prim.Ref_deref,
targs = Vector.new1 (Type.word (WordSize.seqIndex ())),
Expand Down Expand Up @@ -1071,19 +1082,19 @@ fun transform (program: Program.t): Program.t =
in
Vector.new2 (len_stmt, len_ref_stmt)
end
| (Value.ArrayRep.Length, Value.ArrayRep.LengthRef) =>
| (Value.ArrayRep.Length _, Value.ArrayRep.LengthRef) =>
simple (PrimApp {prim = Prim.Ref_ref,
targs = Vector.new1 (Type.word (WordSize.seqIndex ())),
args = Vector.new1 (arg 0)})
| (Value.ArrayRep.LengthRef, Value.ArrayRep.LengthRef) =>
simple (Var (arg 0))
| (Value.ArrayRep.Array _, Value.ArrayRep.Length) =>
| (Value.ArrayRep.Array _, Value.ArrayRep.Length _) =>
simple (PrimApp {prim = Prim.Array_length,
targs = targs,
args = args})
| (Value.ArrayRep.Length, Value.ArrayRep.Length) =>
| (Value.ArrayRep.Length _, Value.ArrayRep.Length _) =>
simple (Var (arg 0))
| (Value.ArrayRep.LengthRef, Value.ArrayRep.Length) =>
| (Value.ArrayRep.LengthRef, Value.ArrayRep.Length _) =>
simple (PrimApp {prim = Prim.Ref_deref,
targs = Vector.new1 (Type.word (WordSize.seqIndex ())),
args = args})
Expand All @@ -1093,13 +1104,13 @@ fun transform (program: Program.t): Program.t =
(case (Value.arrayRep (Value.arrayArg (value (arg 0))),
Value.vectorRep (Value.vectorArg resultValue)) of
(_, Value.VectorRep.Unit) => simple (Var unitVar)
| (Value.ArrayRep.Array _, Value.VectorRep.Length) =>
| (Value.ArrayRep.Array _, Value.VectorRep.Length _) =>
simple (PrimApp {prim = Prim.Array_length,
targs = targs,
args = args})
| (Value.ArrayRep.Length, Value.VectorRep.Length) =>
| (Value.ArrayRep.Length _, Value.VectorRep.Length _) =>
simple (Var (arg 0))
| (Value.ArrayRep.LengthRef, Value.VectorRep.Length) =>
| (Value.ArrayRep.LengthRef, Value.VectorRep.Length _) =>
simple (PrimApp {prim = Prim.Ref_deref,
targs = Vector.new1 (Type.word (WordSize.seqIndex ())),
args = args})
Expand Down Expand Up @@ -1135,13 +1146,13 @@ fun transform (program: Program.t): Program.t =
| Prim.Vector_length =>
(case Value.vectorRep (Value.vectorArg (value (arg 0))) of
Value.VectorRep.Vector _ => doit ()
| Value.VectorRep.Length => simple (Var (arg 0))
| Value.VectorRep.Length _ => simple (Var (arg 0))
| Value.VectorRep.Unit =>
Error.bug "Useless.doitPrim: Vector_length/VectorRep.Unit")
| Prim.Vector_vector =>
(case Value.vectorRep (Value.vectorArg resultValue) of
Value.VectorRep.Vector (eltTy, _) => makeSeq eltTy
| Value.VectorRep.Length =>
| Value.VectorRep.Length _ =>
let
val len_var = Var.newNoname ()
val len_ty = Type.word (WordSize.seqIndex ())
Expand Down Expand Up @@ -1193,7 +1204,7 @@ fun transform (program: Program.t): Program.t =
targs = Vector.new1 Type.unit,
args = WordXVector.toVectorMap (ws, fn _ => unitVar)})
else simple e
| Value.VectorRep.Length =>
| Value.VectorRep.Length _ =>
simple (Const (Const.word
(WordX.fromInt
(WordXVector.length ws,
Expand Down

0 comments on commit 665812b

Please sign in to comment.