First acceptable solution
This commit is contained in:
@@ -1,92 +1,60 @@
|
||||
from typing import Any
|
||||
|
||||
from typing import Any, Callable, Set
|
||||
from .CFG_Node import *
|
||||
|
||||
|
||||
class CFG:
|
||||
def __init__(self, in_node: CFG_Node, out_node: CFG_Node, ast=None):
|
||||
self.in_node = in_node
|
||||
self.out_node = out_node
|
||||
def __init__(self, ast):
|
||||
start = CFG_START()
|
||||
start.dot_style = 'style=filled, color=gray'
|
||||
end = CFG_END()
|
||||
end.dot_style = 'style=filled, color=gray'
|
||||
|
||||
last = ast.cfa(start, end)
|
||||
if last is not None:
|
||||
last.add_child(end)
|
||||
|
||||
self.START = start
|
||||
self.END = end
|
||||
self.ast = ast
|
||||
|
||||
# If AST is provided, filter the graph by removing empty nodes
|
||||
if ast is not None:
|
||||
self._filter_graph()
|
||||
|
||||
def _filter_graph(self):
|
||||
"""
|
||||
Filter the CFG by removing empty nodes and rewiring edges.
|
||||
This should be done once during construction, not during to_dot().
|
||||
"""
|
||||
# Collect all nodes in the graph
|
||||
all_nodes = set()
|
||||
self._collect_nodes(self.in_node, all_nodes)
|
||||
|
||||
# Identify nodes to remove
|
||||
nodes_to_remove = [node for node in all_nodes if self._should_remove_node(node)]
|
||||
|
||||
# Remove nodes and rewrite edges
|
||||
|
||||
# Remove empty nodes and rewire edges
|
||||
all_nodes = self.nodes()
|
||||
nodes_to_remove = [node for node in all_nodes if node.is_empty()]
|
||||
for node in nodes_to_remove:
|
||||
self._remove_node_and_rewire(node)
|
||||
|
||||
def _collect_nodes(self, node, node_set):
|
||||
"""Recursively collect all nodes in the graph"""
|
||||
self.__remove_and_rewire(node)
|
||||
|
||||
def nodes(self):
|
||||
all_nodes = set()
|
||||
self.__collect_nodes(self.START, all_nodes)
|
||||
return all_nodes
|
||||
|
||||
def __collect_nodes(self, node, node_set):
|
||||
if node in node_set:
|
||||
return
|
||||
node_set.add(node)
|
||||
for child in node.children:
|
||||
self._collect_nodes(child, node_set)
|
||||
self.__collect_nodes(child, node_set)
|
||||
|
||||
def _should_remove_node(self, node):
|
||||
"""Determine if a node should be removed from the graph"""
|
||||
# Remove empty nodes (nodes with no meaningful content)
|
||||
# Check for both None and "None" string
|
||||
if hasattr(node, 'label') and ((node.label is None) or (node.label == "None")):
|
||||
# Nodes with AST nodes should NOT be removed - they will get labels from AST
|
||||
if node.ast_node is not None:
|
||||
return False
|
||||
# Also keep global START nodes (they have label=None but should be shown)
|
||||
if hasattr(node, 'dot_label') and node.dot_label() == "START":
|
||||
return False
|
||||
# Remove nodes that have no AST and no meaningful label
|
||||
return True
|
||||
|
||||
# Remove global END nodes (those without function names)
|
||||
if hasattr(node, 'dot_label'):
|
||||
if node.dot_label() in ["END"]:
|
||||
# Keep function-specific END nodes, skip global ones
|
||||
if hasattr(node, 'label') and node.label and '(' in node.label and ')' in node.label:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _remove_node_and_rewire(self, node):
|
||||
"""Remove a node from the graph and rewire edges to bypass it"""
|
||||
# Store original children before modification
|
||||
def __remove_and_rewire(self, node):
|
||||
original_children = list(node.children)
|
||||
|
||||
# For each parent, rewire edges to bypass this node
|
||||
|
||||
for parent in list(node.parents):
|
||||
if node in parent.children:
|
||||
# Find appropriate targets based on node type
|
||||
if hasattr(node, 'dot_shape') and node.dot_shape() == "diamond":
|
||||
# For diamond nodes, preserve T/F branches
|
||||
# For diamond nodes, preserve the true and false bodies
|
||||
if isinstance(node, CFG_DIAMOND):
|
||||
targets = []
|
||||
if len(original_children) >= 1:
|
||||
true_target = self._find_first_non_empty_child(original_children[0])
|
||||
true_target = self.__first_filled_child(original_children[0])
|
||||
if true_target:
|
||||
targets.append(true_target)
|
||||
if len(original_children) >= 2:
|
||||
false_target = self._find_first_non_empty_child(original_children[1])
|
||||
false_target = self.__first_filled_child(original_children[1])
|
||||
if false_target:
|
||||
targets.append(false_target)
|
||||
# For regular nodes, find all non-empty targets
|
||||
else:
|
||||
# For regular nodes, find all non-empty targets
|
||||
targets = []
|
||||
for child in original_children:
|
||||
target = self._find_first_non_empty_child(child)
|
||||
target = self.__first_filled_child(child)
|
||||
if target and target not in targets:
|
||||
targets.append(target)
|
||||
|
||||
@@ -101,97 +69,49 @@ class CFG:
|
||||
node.parents.clear()
|
||||
node.children.clear()
|
||||
|
||||
def _find_first_non_empty_child(self, node):
|
||||
"""Find the first non-empty descendant of a node"""
|
||||
if not self._should_remove_node(node):
|
||||
def __first_filled_child(self, node):
|
||||
if not node.is_empty():
|
||||
return node
|
||||
|
||||
# Recursively check children
|
||||
for child in sorted(node.children, key=lambda n: n.id):
|
||||
result = self._find_first_non_empty_child(child)
|
||||
result = self.__first_filled_child(child)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def to_dot(self) -> str:
|
||||
"""
|
||||
Convert the CFG to DOT format.
|
||||
This method should ONLY handle formatting, not graph modifications.
|
||||
All graph filtering and modifications should be done in the constructor.
|
||||
"""
|
||||
visited = set()
|
||||
lines = ["digraph CFG {"]
|
||||
lines.append(' node [fontname="Helvetica"];')
|
||||
|
||||
def node_label(node: CFG_Node) -> str | None | Any:
|
||||
# Use custom label if available
|
||||
if hasattr(node, 'label') and node.label:
|
||||
# Remove node ID from label for certain node types
|
||||
if isinstance(node, (CFG_START, CFG_END, CFG_CALL, CFG_RETURN)):
|
||||
return node.label
|
||||
else:
|
||||
return node.label
|
||||
|
||||
# Base label from the node
|
||||
base = node.dot_label() if hasattr(node, "dot_label") else ""
|
||||
|
||||
# Semantic label from AST
|
||||
if node.ast_node is not None:
|
||||
semantic = str(node.ast_node)
|
||||
label_content = f"{base}\n{semantic}" if base else semantic
|
||||
return label_content
|
||||
|
||||
return base if base else None
|
||||
|
||||
def node_shape(node: CFG_Node) -> str:
|
||||
return node.dot_shape() if hasattr(node, "dot_shape") else "box"
|
||||
|
||||
def node_style(node: CFG_Node) -> str:
|
||||
# Add styling for special node types
|
||||
styles = []
|
||||
if hasattr(node, 'label') and node.label:
|
||||
if node.label.startswith('CALL') or node.label.startswith('RET'):
|
||||
styles.append('style=filled')
|
||||
styles.append('color=orange')
|
||||
elif node.label.startswith('START') or node.label.startswith('END'):
|
||||
styles.append('style=filled')
|
||||
styles.append('color=green')
|
||||
return ', '.join(styles) if styles else ''
|
||||
|
||||
def visit(node: CFG_Node):
|
||||
if node.id in visited:
|
||||
return
|
||||
|
||||
visited.add(node.id)
|
||||
|
||||
label = node_label(node)
|
||||
if label is None:
|
||||
# This shouldn't happen if the constructor did its job properly
|
||||
return
|
||||
|
||||
shape = node_shape(node)
|
||||
style = node_style(node)
|
||||
lines = ["digraph CFG {", ' node [fontname="Helvetica"];']
|
||||
|
||||
def emit(node: CFG_Node):
|
||||
label = node.dot_label()
|
||||
shape = node.dot_shape
|
||||
style = node.dot_style
|
||||
style_str = f", {style}" if style else ""
|
||||
lines.append(
|
||||
f' n{node.id} [label="{label}", shape={shape}{style_str}];'
|
||||
)
|
||||
lines.append(f' n{node.id} [label="{label}", shape={shape}{style_str}];')
|
||||
|
||||
# Add edges to children
|
||||
for i, child in enumerate(sorted(node.children, key=lambda n: n.id)):
|
||||
# Add edge labels for diamond nodes (conditional branches)
|
||||
edge_label = ""
|
||||
if hasattr(node, 'dot_shape') and node.dot_shape() == "diamond":
|
||||
if isinstance(node, CFG_DIAMOND):
|
||||
if i == 0:
|
||||
edge_label = ' [label="T"]'
|
||||
elif i == 1:
|
||||
edge_label = ' [label="F"]'
|
||||
|
||||
lines.append(f" n{node.id} -> n{child.id}{edge_label};")
|
||||
visit(child)
|
||||
|
||||
# Start the CFG traversal from the entry node
|
||||
visit(self.in_node)
|
||||
self.traverse(emit, start=self.START)
|
||||
lines.append("}")
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines)
|
||||
|
||||
# Reusable traversal function
|
||||
def traverse(self, fn: Callable[[CFG_Node], Any], start: CFG_Node | None = None) -> None:
|
||||
start = start or self.START
|
||||
visited: Set[int] = set()
|
||||
|
||||
def visit(node: CFG_Node):
|
||||
if node.id in visited:
|
||||
return
|
||||
visited.add(node.id)
|
||||
fn(node)
|
||||
for child in sorted(node.children, key=lambda n: n.id):
|
||||
visit(child)
|
||||
visit(start)
|
||||
@@ -5,7 +5,10 @@ class CFG_Node:
|
||||
self.ast_node = ast_node
|
||||
self.children = set()
|
||||
self.parents = set()
|
||||
self.label = None # Optional label for the node
|
||||
|
||||
self.label = None
|
||||
self.dot_shape = 'box'
|
||||
self.dot_style = ''
|
||||
|
||||
self.id = CFG_Node.__counter
|
||||
CFG_Node.__counter += 1
|
||||
@@ -36,44 +39,58 @@ class CFG_Node:
|
||||
parent.children.remove(self)
|
||||
self.parents.remove(parent)
|
||||
|
||||
def __str__(self):
|
||||
if self.label:
|
||||
return f"CFG_Node({self.id}, label='{self.label}')"
|
||||
elif self.ast_node:
|
||||
return f"CFG_Node({self.id}, ast={type(self.ast_node).__name__})"
|
||||
else:
|
||||
return f"CFG_Node({self.id})"
|
||||
def dot_label(self):
|
||||
# Prioritize custom label
|
||||
if self.label is not None:
|
||||
return self.label
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
# Build label from AST node
|
||||
if self.ast_node is not None:
|
||||
return str(self.ast_node)
|
||||
|
||||
return None
|
||||
|
||||
def is_filled(self):
|
||||
return not self.is_empty()
|
||||
|
||||
def is_empty(self):
|
||||
# Node is empty if it has no label and no related AST node
|
||||
if self.label is None or self.label == "None":
|
||||
if self.ast_node is not None:
|
||||
# Node belongs to a ast node
|
||||
return False
|
||||
return True
|
||||
# Node is required for the control flow
|
||||
return False
|
||||
|
||||
class CFG_START(CFG_Node):
|
||||
def dot_shape(self):
|
||||
return "box"
|
||||
|
||||
def dot_label(self):
|
||||
return "START"
|
||||
|
||||
def __init__(self, ast_node=None):
|
||||
super().__init__(ast_node)
|
||||
self.dot_shape = "ellipse"
|
||||
self.dot_style = 'style=filled, color=green'
|
||||
self.label = "START"
|
||||
|
||||
class CFG_END(CFG_Node):
|
||||
def dot_shape(self):
|
||||
return "box"
|
||||
|
||||
def dot_label(self):
|
||||
return "END"
|
||||
|
||||
def __init__(self, ast_node=None):
|
||||
super().__init__(ast_node)
|
||||
self.dot_shape = "ellipse"
|
||||
self.dot_style = 'style=filled, color=green'
|
||||
self.label = "END"
|
||||
|
||||
class CFG_DIAMOND(CFG_Node):
|
||||
def dot_shape(self):
|
||||
return "diamond"
|
||||
|
||||
def __init__(self, ast_node=None):
|
||||
super().__init__(ast_node)
|
||||
self.dot_shape = "diamond"
|
||||
self.label = "<?>"
|
||||
|
||||
class CFG_CALL(CFG_Node):
|
||||
def dot_shape(self):
|
||||
return "box"
|
||||
|
||||
def __init__(self, ast_node=None):
|
||||
super().__init__(ast_node)
|
||||
self.dot_style = 'style=filled, color=orange'
|
||||
self.dot_shape = "box"
|
||||
|
||||
class CFG_RETURN(CFG_Node):
|
||||
def dot_shape(self):
|
||||
return "box"
|
||||
def __init__(self, ast_node=None):
|
||||
super().__init__(ast_node)
|
||||
self.dot_style = 'style=filled, color=orange'
|
||||
self.dot_shape = "box"
|
||||
|
||||
Reference in New Issue
Block a user