-
Notifications
You must be signed in to change notification settings - Fork 0
/
symutil.py
333 lines (261 loc) · 11.7 KB
/
symutil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Utilities for treating symbolic expressions.
Created on Wed Nov 1 14:46:34 2017
@author: Juha Jeronen <[email protected]>
"""
import sympy as sy
from sympy.core.function import UndefinedFunction
from util import name_derivative, degreek
def make_function(name, *deps):
    """Create an unspecified (applied) function with known dependencies.

    Convenience wrapper around SymPy's two-step construction of unknown
    functions.

    Parameters:
        name: str
            Symbol name of the function to define.

        *deps: sy.Symbol
            Each item is either an independent variable (a bare symbol, e.g.
            ``sy.symbols("x")``), or, to build a layer cake of dependencies,
            a symbol previously returned by ``make_function()`` itself.

    Returns:
        sy.Function
            The applied function (see below).

    Background: SymPy sets up an unspecified function in two steps:

        1) An "undefined function": a symbol standing for a generic unknown
           function with the given symbol name.

           Each function must have a unique symbol name. SymPy distinguishes
           between symbols by symbol name, flags and Python object type
           (e.g. sy.Symbol vs. sy.Function).

        2) An "applied function" (of symbols). Calling an UndefinedFunction
           instance with symbols as arguments returns an otherwise
           unspecified function that formally depends on those symbols.
           Importantly, these dependencies are recognized in symbolic
           differentiation.

           SymPy creates a new Python type (class) for each function name,
           named after the symbol name of the undefined-function instance
           that was used to create the applied-function instance.

    Example, raw SymPy::

        import sympy as sy
        x, y = sy.symbols("x, y")               # independent variables
        λf = sy.symbols("f", cls=sy.Function)   # undefined function
        f = λf(x, y)                            # applied function

    The same using ``make_function()``::

        import sympy as sy
        x, y = sy.symbols("x, y")
        f = make_function("f", x, y)

    In both cases::

        type(f)  # --> f
    """
    # Step 1: the undefined function (a new Python type named ``name``).
    undefined = sy.symbols(name, cls=sy.Function)
    # Step 2: apply it to the dependencies.
    applied = undefined(*deps)
    return applied
def sortkey(sym):
    """Key function for sorting ``sy.Symbol`` objects case-insensitively.

    Returns:
        str
            ``str(sym)``, converted to lowercase.
    """
    text = str(sym)
    return text.lower()
def nameof(sym):
    """Return the symbol name of ``sym`` as a str.

    Objects that carry a ``name`` attribute report it directly. For anything
    else, the name of the object's Python class is returned (e.g. an
    undefined function has no name, but its *class* has a ``__name__``).
    """
    if not hasattr(sym, "name"):
        return sym.__class__.__name__
    return sym.name
def nameof_as_symbol(sym):
    """Return a new ``sy.Symbol`` that has the same symbol name as ``sym``.

    Useful to unify handling of ``Symbol`` and ``UndefinedFunction`` objects
    in use cases that need only the symbol name, since these datatypes store
    their symbol names differently.

    Assumptions are copied from ``sym`` (via ``sym.assumptions0``).

    **Danger**: if ``sym`` is an ``UndefinedFunction``, the returned ``Symbol``
    is marked with default assumptions (as of SymPy 1.0, ``commutative=True``).
    Obviously, a ``Symbol`` cannot represent an ``UndefinedFunction``
    completely faithfully, as these datatypes are *intended* to have
    different behavior.

    Parameters:
        sym: sy.Symbol, or an applied undefined function

    Returns:
        sy.Symbol
            Exactly one symbol, whose name is ``nameof(sym)``.
    """
    # Use the Symbol constructor directly. ``sy.symbols()`` would *parse* the
    # name (splitting on commas/whitespace and expanding range syntax such as
    # "x:3"), so an unusual name could produce a tuple or several symbols
    # instead of one Symbol carrying the original name verbatim.
    return sy.Symbol(nameof(sym), **sym.assumptions0)
def strip_function_arguments(expr):
    """Recursively strip the argument lists of unknown functions in ``expr``.

    In the output, every applied ``UndefinedFunction`` is replaced by a bare
    ``Symbol`` with the same symbol name.

    Mainly useful for printing. With several layers of dependencies, taking
    partial derivatives via the chain rule often produces kilometer-long
    argument lists, which make an unstripped printout of ``expr`` unreadable
    for humans.
    """
    # ``map_instancesof_in()`` cannot implement this: each undefined function
    # is an instance of its own Python type, and it is *that type* that is an
    # instance of UndefinedFunction. This pattern is specific to
    # UndefinedFunctions, so the recursion is written out by hand.
    if isinstance(expr.__class__, UndefinedFunction):
        # The arguments are discarded anyway, so there is no need to recurse.
        return nameof_as_symbol(expr)
    if expr.is_Atom:
        return expr
    # Any other compound: rebuild it from recursively stripped arguments.
    stripped_args = [strip_function_arguments(arg) for arg in expr.args]
    rebuild = type(expr)
    return rebuild(*stripped_args)
def canonize_derivative(expr):
    """Return ``expr`` with its varlist (diff w.r.t. what) sorted.

    Useful for higher derivatives of C^k functions.

    Parameters:
        expr: sy.Derivative

    Returns:
        An unevaluated derivative of the same type as ``expr``. The variables
        w.r.t. which the derivative is taken are sorted in canonical order
        using ``symutil.sortkey()``.

    Raises:
        TypeError
            If ``expr`` is not a ``sy.Derivative``.
    """
    if isinstance(expr, sy.Derivative):
        differentiated, *wrt = expr.args
        rebuild = type(expr)
        return rebuild(differentiated, *sorted(wrt, key=sortkey), evaluate=False)
    raise TypeError("Expected Derivative, got {} {}".format(type(expr), expr))
def derivatives_needed_by(expr, canonize=True):
    """Return a sorted list describing the derivatives that ``expr`` needs.

    This works by matching unevaluated ``Derivative`` objects in ``expr``,
    recursively. The search does not descend into a matched derivative.

    Parameters:
        expr: sy.Expr
            The expression.

        canonize: bool
            If True, the varlist (diff w.r.t. what) of each derivative will
            be sorted. (Useful for higher derivatives of C^k functions.)

            If False, the varlist is passed through as-is.

    Returns:
        list
            A sorted list of unique tuples (f, x1, x2, ..., xn), where:

                f: the function that is differentiated. This is the
                   original symbol from ``expr``.

                x1, x2, ..., xn: variables with respect to which f is
                   differentiated. Higher derivatives are represented by
                   repeating the same symbol, e.g. ∂²f(x)/∂x² -> (f,x,x).

            Each tuple is the ``args`` of the (possibly canonized)
            derivative. The list is ordered by ``[sortkey(x) for x in item]``.
            Derivatives of numbers (e.g. "d(0)/dx") are ignored.

            NOTE(review): the "repeated symbol" layout matches the SymPy
            version this module was written against. Newer SymPy stores
            (variable, count) pairs in ``Derivative.args``; confirm this
            against the SymPy version in use.
    """
    maybe_canonize = canonize_derivative if canonize else (lambda d: d)
    found = set()  # unique ``args`` tuples, so duplicates collapse

    def process(e):
        if isinstance(e, sy.Derivative):
            # args[0] = function, args[1:] = diff. w.r.t. what
            if not e.args[0].is_Number:  # ignore nonsense like "d(0)/dx"
                found.add(maybe_canonize(e).args)
        elif not e.is_Atom:  # compound other than a derivative
            for arg in e.args:
                process(arg)

    process(expr)
    # Sort the derivatives (the final varlists themselves are passed through as-is).
    return sorted(found, key=lambda item: [sortkey(x) for x in item])
def map_instancesof_in(func, cls, expr):
    """Recursively apply ``func`` to every instance of ``cls`` in ``expr``.

    The tree is processed bottom-up. The arguments of a node are mapped
    first, the node is rebuilt from the mapped arguments, and only then is
    ``func`` applied to the rebuilt node (if it matches ``cls``).

    To adapt a different call signature, or to call a member of ``Expr``,
    wrap it in a helper function that returns the replacement::

        def apply_helper(expr):
            return expr.some_member_function()

        result = map_instancesof_in(apply_helper, whatever, whatever)

    Parameters:
        func: function ``Expr`` -> ``Expr``
            Each subexpression that matches ``cls`` is replaced by
            ``func(subexpr)``. The output may be any subclass of Expr, i.e.
            type changes are allowed. It must be an Expr because the output,
            just like the input, must be a valid node of a SymPy expression
            tree.

        cls: type, or tuple of types
            SymPy expression type(s) such as ``Add``, ``Mul``, ``Subs``, ...

        expr: SymPy expression object

    Returns:
        The processed ``Expr``, in which ``func`` has been applied to each
        instance of ``cls`` in ``expr``.
    """
    if not expr.is_Atom:
        # Children first, then the node itself.
        mapped_args = [map_instancesof_in(func, cls, arg) for arg in expr.args]
        # TODO: do we need to copy also something other than assumptions here?
        rebuilt = type(expr)(*mapped_args, **expr.assumptions0)
        return func(rebuilt) if isinstance(rebuilt, cls) else rebuilt
    # Atoms have no args, and their constructors take type-specific
    # positional arguments, so they are never rebuilt; only mapped.
    return func(expr) if isinstance(expr, cls) else expr
def collect_const_in(expr):
    """Collect constant factors in every sum nested inside ``expr``.

    ``sy.collect_const`` is applied to each ``sy.Add`` found in ``expr``,
    innermost sums first.
    """
    collector = sy.collect_const
    target_type = sy.Add
    return map_instancesof_in(collector, target_type, expr)
def apply_substitutions_in(expr):
    """Evaluate the unevaluated substitutions (``sy.Subs``) nested inside ``expr``.

    Each ``sy.Subs`` node found in ``expr`` is replaced by the result of its
    ``doit()`` method.
    """
    return map_instancesof_in(lambda subs: subs.doit(), sy.Subs, expr)
def derivatives_to_names_in(expr, as_fortran_identifier=False):
    """Rename derivative objects in ``expr``, recursively.

    Derivatives in ``expr`` are replaced by bare symbols. The symbols are
    named using a naming scheme that depends on the option
    ``as_fortran_identifier``.

    Parameters:
        expr: sy.Expr
            The expression to process.

        as_fortran_identifier: bool
            If False, the generated symbol names are a Unicode representation
            of standard mathematical notation, e.g. ∂f/∂x.

            If True, the generated symbol names are sanitized for use as
            Fortran identifiers, e.g. df_dx. Note that this only sanitizes
            the derivative notation, and especially, does **not** remove
            Greek characters; for that, see ``util.degreek()``.

    Returns:
        sy.Expr:
            The processed expression.
    """
    def rename(deriv):
        # Drop function argument lists, so only bare names are stringified.
        deriv = strip_function_arguments(deriv)
        fname, *vnames = (str(arg) for arg in deriv.args)
        identifier = name_derivative(fname, vnames,
                                     as_fortran_identifier=as_fortran_identifier)
        # We must return an Expr, so wrap the identifier in a Symbol. Use the
        # Symbol constructor: ``sy.symbols()`` would parse the string
        # (splitting on commas/whitespace, expanding ":" ranges) and could
        # return a tuple instead of a single Symbol.
        return sy.Symbol(identifier)
    return map_instancesof_in(rename, sy.Derivative, expr)
def is_symmetric(mat):
    """Return whether a ``sy.Matrix`` is symmetric.

    A non-square matrix is never symmetric. Entries are compared with ``==``
    (for SymPy objects, structural equality), scanning the strict upper
    triangle and stopping at the first mismatch.
    """
    nrows, ncols = mat.shape
    if nrows != ncols:
        return False
    for row in range(nrows):
        for col in range(row + 1, nrows):
            if not mat[col, row] == mat[row, col]:
                return False
    return True
def degreek_in(expr, short=True):
    """Remove Greek letters from symbol names in ``expr``, recursively.

    Delegates to ``util.degreek`` for each encountered ``sy.Symbol``. The
    renamed symbol keeps the assumptions of the original (``assumptions0``).

    Does not rename function symbols! (Strip their arguments first to convert
    them to bare symbols.)

    Parameters:
        expr: sy.Expr
            The expression to process.

        short: bool
            Passed through to ``util.degreek``.

    Returns:
        sy.Expr
            The processed expression.
    """
    def rename(sym):
        # Use the Symbol constructor: ``sy.symbols()`` would parse the new
        # name (commas/whitespace/":" ranges) and could return a tuple.
        return sy.Symbol(degreek(nameof(sym), short=short), **sym.assumptions0)
    return map_instancesof_in(rename, sy.Symbol, expr)
def voigt_mat_idx():
    """Return the index conversion table between a Voigt-packed vector and a matrix.

    For symmetric rank-2 tensors in 3D space.

    Returns:
        ((k, (r, c)), ...)
        where
            k: index in the Voigt-packed vector (0-based)
            r: row in the matrix (0-based)
            c: column in the matrix (0-based)

        The matrix is symmetric; (r, c) pairs are given for the upper
        triangle only.

    Example (math notation)::

        ε_voigt = [εxx, εyy, εzz, εyz, εzx, εxy] = [ε1, ε2, ε3, ε4, ε5, ε6]

        ε_mat = [[ε1, ε6, ε5],
                 [ε6, ε2, ε4],
                 [ε5, ε4, ε3]]
    """
    # Matrix positions listed in Voigt order; enumerate() supplies k.
    upper_triangle = ((0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1))
    return tuple(enumerate(upper_triangle))
def voigt_to_mat(vec):
    """Convert a Voigt-packed vector to a matrix.

    For symmetric rank-2 tensors in 3D space.

    Parameters:
        vec: indexable, length 6
            See ``voigt_mat_idx()`` for which component is which.

    Returns:
        sy.Matrix, shape (3, 3)
            The matrix representation of ``vec``. Guaranteed to be symmetric.

    Raises:
        ValueError
            If ``len(vec) != 6``.
    """
    length = len(vec)
    if length != 6:
        raise ValueError("vec should have length 6, got {invalid}".format(invalid=length))
    mat = sy.Matrix.zeros(3, 3)
    for k, (r, c) in voigt_mat_idx():
        # Fill both triangles from the same Voigt component.
        value = vec[k]
        mat[r, c] = value
        mat[c, r] = value
    assert is_symmetric(mat)  # internal sanity check
    return mat