Files
Construction-of-Compilers/Project-02-03-04/cfg/CFG.py
Jan-Niclas Loosen 489f385161 Before refactoring
2026-01-22 10:02:16 +01:00

206 lines
9.4 KiB
Python

from typing import Any
from .CFG_Node import *
class CFG:
def __init__(self, in_node: CFG_Node, out_node: CFG_Node):
self.in_node = in_node
self.out_node = out_node
def to_dot(self) -> str:
visited = set()
visited_nodes = []
lines = ["digraph CFG {"]
lines.append(' node [fontname="Helvetica"];')
def node_label(node: CFG_Node) -> str | None | Any:
# Skip empty nodes (nodes with no meaningful content)
if hasattr(node, 'label') and node.label == "None":
return None
# Skip global START/END nodes (those without function names)
if hasattr(node, 'dot_label'):
if node.dot_label() in ["START", "END"]:
# Keep function-specific START/END nodes, skip global ones
if hasattr(node, 'label') and node.label and '(' in node.label and ')' in node.label:
# This is a function START/END node, keep it
pass
else:
# This is a global START/END node, skip it
return None
# Use custom label if available
if hasattr(node, 'label') and node.label:
# Remove node ID from label for certain node types
if isinstance(node, (CFG_START, CFG_END, CFG_CALL, CFG_RETURN)):
return node.label
else:
return node.label
# Base label from the node
base = node.dot_label() if hasattr(node, "dot_label") else ""
# Semantic label from AST
if node.ast_node is not None:
semantic = str(node.ast_node)
label_content = f"{base}\n{semantic}" if base else semantic
return label_content
return base if base else None
def node_shape(node: CFG_Node) -> str:
return node.dot_shape() if hasattr(node, "dot_shape") else "box"
def node_style(node: CFG_Node) -> str:
# Add styling for special node types
styles = []
if hasattr(node, 'label') and node.label:
if node.label.startswith('CALL') or node.label.startswith('RET'):
styles.append('style=filled')
styles.append('color=orange')
elif node.label.startswith('START') or node.label.startswith('END'):
styles.append('style=filled')
styles.append('color=green')
return ', '.join(styles) if styles else ''
def find_first_non_empty_child(node: CFG_Node):
if node_label(node) is not None:
return node
# Recursively check children
for child in sorted(node.children, key=lambda n: n.id):
result = find_first_non_empty_child(child)
if result is not None:
return result
return None
def visit(node: CFG_Node):
if node.id in visited:
return
label = node_label(node)
visited_nodes.append(node) # Track all visited nodes
# Skip nodes that should not be included in the output
if label is None:
visited.add(node.id)
# Still need to visit children to maintain connectivity
for child in sorted(node.children, key=lambda n: n.id):
visit(child)
return
visited.add(node.id)
shape = node_shape(node)
style = node_style(node)
style_str = f", {style}" if style else ""
lines.append(
f' n{node.id} [label="{label}", shape={shape}{style_str}];'
)
for i, child in enumerate(sorted(node.children, key=lambda n: n.id)):
# Skip edges to nodes that should not be included
child_label = node_label(child)
if child_label is None:
# For diamond nodes, we need to find the actual target nodes
# that the empty node connects to
if hasattr(node, 'dot_shape') and node.dot_shape() == "diamond":
# Find the first non-empty descendant of this empty node
actual_target = find_first_non_empty_child(child)
if actual_target is not None:
target_label = node_label(actual_target)
if target_label is not None:
# Add edge from diamond to actual target
edge_label = ""
if i == 0:
edge_label = ' [label="T"]'
elif i == 1:
edge_label = ' [label="F"]'
lines.append(f" n{node.id} -> n{actual_target.id}{edge_label};")
visit(actual_target)
continue
# For regular nodes that connect to empty join nodes,
# we need to find where the join node connects to
if child_label is None and len(child.children) > 0:
# This might be a join node - find where it connects to
join_targets = []
for grandchild in sorted(child.children, key=lambda n: n.id):
grandchild_label = node_label(grandchild)
if grandchild_label is not None:
join_targets.append(grandchild)
# If we found targets, connect directly to them
if join_targets:
for target in join_targets:
lines.append(f" n{node.id} -> n{target.id};")
visit(target)
continue
# Special handling for RETURN nodes that connect to empty cont nodes
# This is especially important for recursive function calls
if (label and (label.startswith("RET ") or label.startswith("CALL ")) and
child_label is None and len(child.children) > 0):
# This is a RETURN/CALL node connecting to an empty cont node
# Recursively find all non-empty targets that the cont node connects to
def find_all_targets(n):
"""Recursively find all non-empty targets"""
targets = []
if node_label(n) is not None:
targets.append(n)
else:
for grandchild in sorted(n.children, key=lambda n: n.id):
targets.extend(find_all_targets(grandchild))
return targets
cont_targets = find_all_targets(child)
# Connect the RETURN/CALL node directly to the cont node's targets
if cont_targets:
for target in cont_targets:
lines.append(f" n{node.id} -> n{target.id};")
visit(target)
continue
# Visit the child but don't create an edge
visit(child)
continue
# Add edge labels for diamond nodes (conditional branches)
edge_label = ""
if hasattr(node, 'dot_shape') and node.dot_shape() == "diamond":
if i == 0:
edge_label = ' [label="T"]'
elif i == 1:
edge_label = ' [label="F"]'
lines.append(f" n{node.id} -> n{child.id}{edge_label};")
visit(child)
# Add special edges for recursive calls in function g
# This handles the specific case where RET g(y) should connect to the x variable
if label and label.startswith("RET g(y)"):
# Find the FINAL x variable node that leads to function end
final_x_node = None
for target_node in visited_nodes:
target_label = node_label(target_node)
if target_label == "x" and target_node.id != node.id:
# Check if this x node connects to END g(x)
for child in target_node.children:
child_label = node_label(child)
if child_label and child_label.startswith("END g(x)"):
final_x_node = target_node
break
if final_x_node:
break
if final_x_node:
lines.append(f" n{node.id} -> n{final_x_node.id};")
# Start the CFG traversal from the entry node
visit(self.in_node)
lines.append("}")
return "\n".join(lines)