import networkx as nx
from openmdao.main.expreval import ConnectedExprEvaluator
from openmdao.main.pseudocomp import PseudoComponent
class ExprMapper(object):
    """A mapping between source expressions and destination expressions.

    Connections are stored in a directed graph (``self._exprgraph``) whose
    nodes are expression text strings. Each node stores its expression
    object under the ``'expr'`` attribute, and an edge may store the
    pseudocomponent that implements it under the ``'pcomp'`` attribute.
    """

    def __init__(self, scope):
        # graph of source expressions to destination expressions
        self._exprgraph = nx.DiGraph()
        self._scope = scope

    def get_output_exprs(self):
        """Return all destination expressions at the output boundary.

        These are expressions that have at least one incoming connection
        and reference no component names.
        """
        graph = self._exprgraph
        return [data['expr'] for node, data in graph.nodes(data=True)
                if graph.in_degree(node) > 0
                and not data['expr'].get_referenced_compnames()]

    def get_expr(self, text):
        """Return the expression object stored for *text*, or None if
        *text* is not a node of the graph.
        """
        node = self._exprgraph.node.get(text)
        if node:
            return node['expr']
        return None

    def list_connections(self, show_passthrough=True, visible_only=False):
        """Return a list of tuples of the form (outvarname, invarname).

        Connections where either end is an expression that refers to the
        parent scope are always omitted.

        show_passthrough: bool
            If False, omit connections where either end contains no ``'.'``.
        visible_only: bool
            If True, a connection implemented by a pseudocomponent is
            replaced by that pseudocomponent's
            ``list_connections(is_hidden=True)`` result, and any other
            connection whose source or destination component is a
            PseudoComponent is dropped.
        """
        graph = self._exprgraph
        excludes = set(name for name, data in graph.nodes(data=True)
                       if data['expr'].refs_parent())
        lst = [(u, v, data) for u, v, data in graph.edges(data=True)
               if u not in excludes and v not in excludes]
        if not show_passthrough:
            lst = [(u, v, data) for u, v, data in lst
                   if '.' in u and '.' in v]

        if not visible_only:
            return [(u, v) for u, v, _ in lst]

        newlst = []
        for u, v, data in lst:
            pcomp = data.get('pcomp')
            if pcomp is not None:
                newlst.extend(pcomp.list_connections(is_hidden=True))
                continue
            srccmp = getattr(self._scope, u.split('.', 1)[0], None)
            dstcmp = getattr(self._scope, v.split('.', 1)[0], None)
            if isinstance(srccmp, PseudoComponent) or \
               isinstance(dstcmp, PseudoComponent):
                continue
            newlst.append((u, v))
        return newlst

    def get_source(self, dest_expr):
        """Return the text of the source expression connected to the
        destination expression *dest_expr*, or None if it has no source.
        """
        preds = self._exprgraph.pred.get(dest_expr)
        if preds:
            # check_connect rejects a second source for a destination, so
            # the first predecessor is the source. next(iter(...)) works
            # for both Python 2 lists and Python 3 dict views, whereas
            # keys()[0] fails on Python 3.
            return next(iter(preds))
        return None

    def get_dests(self, src_expr):
        """Return the list of destination expression objects connected to
        the source expression *src_expr*; empty if *src_expr* is unknown.
        """
        graph = self._exprgraph
        succs = graph.succ.get(src_expr, {})
        # graph.node is a mapping of node -> attribute dict; it must be
        # indexed, not called.
        return [graph.node[name]['expr'] for name in succs]

    def remove(self, compname):
        """Remove any connections referring to the given component."""
        refs = self.find_referring_exprs(compname)
        if refs:
            self._exprgraph.remove_nodes_from(refs)
            self._remove_disconnected_exprs()

    def connect(self, srcexpr, destexpr, scope, pseudocomp=None):
        """Record a connection from *srcexpr* to *destexpr* in the graph.

        When the destination is neither a PseudoComponent nor in the
        parent scope and the source references at most one variable, the
        source and destination traits are looked up via ``get_dyn_trait``
        (presumably so unknown variables are rejected early). If the
        destination is a simple variable name, the current source value is
        also checked with the destination trait's ``validate``.

        pseudocomp:
            If not None, stored on the new edge under the ``'pcomp'`` key.
        """
        src = srcexpr.text
        dest = destexpr.text

        srcvars = srcexpr.get_referenced_varpaths(copy=False)
        destvar = destexpr.get_referenced_varpaths().pop()

        destcompname, destcomp, destvarname = scope._split_varpath(destvar)
        desttrait = None
        srccomp = None

        if not isinstance(destcomp, PseudoComponent) and \
           not destvar.startswith('parent.') and not len(srcvars) > 1:
            for srcvar in srcvars:
                if not srcvar.startswith('parent.'):
                    srccompname, srccomp, srcvarname = scope._split_varpath(srcvar)
                    if not isinstance(srccomp, PseudoComponent):
                        src_io = 'in' if srccomp is scope else 'out'
                        srccomp.get_dyn_trait(srcvarname, src_io)
                        if desttrait is None:
                            dest_io = 'out' if destcomp is scope else 'in'
                            desttrait = destcomp.get_dyn_trait(destvarname, dest_io)

            if not isinstance(srccomp, PseudoComponent) and \
               not srcexpr.refs_parent() and desttrait is not None:
                # punt if dest is not just a simple var name.
                # validity will still be checked at execution time
                if destvar == destexpr.text:
                    ttype = desttrait.trait_type
                    if not ttype:
                        ttype = desttrait
                    srcval = srcexpr.evaluate()
                    if ttype.validate:
                        ttype.validate(destcomp, destvarname, srcval)
                    # else: no validate function on destination trait. Most
                    # likely it's a property trait. No way to validate
                    # without unknown side effects. Have to wait until later
                    # when data actually gets passed via the connection.

        if src not in self._exprgraph:
            self._exprgraph.add_node(src, expr=srcexpr)
        if dest not in self._exprgraph:
            self._exprgraph.add_node(dest, expr=destexpr)

        self._exprgraph.add_edge(src, dest)
        if pseudocomp is not None:
            self._exprgraph[src][dest]['pcomp'] = pseudocomp

    def find_referring_exprs(self, name):
        """Return a list of expression strings that reference the given
        name, which can refer to either a variable or a component.
        """
        return [node for node, data in self._exprgraph.nodes(data=True)
                if data['expr'].refers_to(name)]

    def _remove_disconnected_exprs(self):
        """Remove all expressions no longer connected to anything and
        return the list of removed node names.
        """
        graph = self._exprgraph
        # Collect first so the graph is not mutated while being iterated.
        to_remove = [node for node in graph.nodes()
                     if graph.in_degree(node) == 0
                     and graph.out_degree(node) == 0]
        graph.remove_nodes_from(to_remove)
        return to_remove

    def disconnect(self, srcpath, destpath=None):
        """Disconnect the given expressions/variables/components.

        If *destpath* is None, every connection into or out of an
        expression that refers to *srcpath* is removed, along with those
        expressions. Otherwise the edge (srcpath, destpath) is removed if
        it exists. Failing that, both paths are treated as variable names,
        and every edge from an expression referring to *srcpath* to one
        referring to *destpath* is removed.

        If a removed connection touches a ``_pseudo_*`` component, that
        pseudocomponent's own connections are removed as well.

        Returns a tuple ``(edges, pcomps)``: the set of (src, dest)
        connections to remove and the set of names of pseudocomponents
        to remove.
        """
        graph = self._exprgraph
        to_remove = set()
        exprs = []
        pcomps = set()

        if destpath is None:
            exprs = self.find_referring_exprs(srcpath)
            for expr in exprs:
                to_remove.update(graph.edges(expr))
                to_remove.update(graph.in_edges(expr))
        elif graph.has_edge(srcpath, destpath):
            # Check the edge itself (not just that both nodes exist) so
            # that graph[srcpath][destpath] cannot raise KeyError.
            to_remove.add((srcpath, destpath))
            data = graph[srcpath][destpath]
            if 'pcomp' in data:
                pcomps.add(data['pcomp'].name)
        else:
            # assume they're disconnecting two variables, so find
            # connected exprs that refer to them
            src_exprs = set(self.find_referring_exprs(srcpath))
            dest_exprs = set(self.find_referring_exprs(destpath))
            to_remove.update([(src, dest) for src, dest in graph.edges()
                              if src in src_exprs and dest in dest_exprs])

        added = []
        for src, dest in to_remove:
            if src.startswith('_pseudo_'):
                pcompname = src.split('.', 1)[0]
            elif dest.startswith('_pseudo_'):
                pcompname = dest.split('.', 1)[0]
            else:
                continue
            pcomp = getattr(self._scope, pcompname)
            added.extend(pcomp.list_connections())
            pcomps.add(pcomp.name)
        to_remove.update(added)

        graph.remove_edges_from(to_remove)
        graph.remove_nodes_from(exprs)
        self._remove_disconnected_exprs()
        return to_remove, pcomps

    def check_connect(self, src, dest, scope):
        """Check validity of connecting a source expression to a
        destination expression, and determine if we need to create links
        to pseudocomps.

        Returns ``(srcexpr, destexpr, pseudo_type)`` where ``pseudo_type``
        is the result of ``_needs_pseudo``. Raises RuntimeError if *dest*
        already has a source (via ``scope.raise_exception``) or if both
        ends refer to the same component.
        """
        existing_src = self.get_source(dest)
        if existing_src is not None:
            scope.raise_exception("'%s' is already connected to source '%s'" %
                                  (dest, existing_src), RuntimeError)

        destexpr = ConnectedExprEvaluator(dest, scope, is_dest=True)
        srcexpr = ConnectedExprEvaluator(src, scope, getter='get_attr')

        srccomps = srcexpr.get_referenced_compnames()
        destcomps = list(destexpr.get_referenced_compnames())

        if destcomps and destcomps[0] in srccomps:
            raise RuntimeError("'%s' and '%s' refer to the same component." % (src, dest))

        return srcexpr, destexpr, self._needs_pseudo(scope, srcexpr, destexpr)

    def _needs_pseudo(self, parent, srcexpr, destexpr):
        """Return a non-None pseudo_type if srcexpr and destexpr require a
        pseudocomp to be created.

        Returns ``'multi_var_expr'`` when the source is more than a bare
        variable reference, ``'units'`` when both ends have differing units
        metadata, and None otherwise. *parent* is not used.
        """
        srcrefs = list(srcexpr.refs())
        if srcrefs and srcrefs[0] != srcexpr.text:
            # expression is more than just a simple variable reference,
            # so we need a pseudocomp
            return 'multi_var_expr'

        destmeta = destexpr.get_metadata('units')
        srcmeta = srcexpr.get_metadata('units')

        srcunit = srcmeta[0][1] if srcmeta else None
        destunit = destmeta[0][1] if destmeta else None

        if destunit and srcunit and destunit != srcunit:
            return 'units'

        return None

    def list_pseudocomps(self):
        """Return the names of all pseudocomponents attached to edges."""
        return [data['pcomp'].name for u, v, data in
                self._exprgraph.edges(data=True) if 'pcomp' in data]