Refactor cfg_build.py

This commit is contained in:
Jan-Niclas Loosen
2026-01-22 18:18:13 +01:00
parent 489f385161
commit 3abe8581b5
3 changed files with 124 additions and 77 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 150 KiB

After

Width:  |  Height:  |  Size: 66 KiB

View File

@@ -14,159 +14,202 @@ import syntax
FUNCTIONS = {} FUNCTIONS = {}
class CONST(compiler.CONST): class CONST(compiler.CONST):
def cfa(self, pred, end): def cfa(self, pred, end = None):
n = CFG_Node(self) node = CFG_Node(self)
pred.add_child(n) pred.add_child(node)
n.add_child(end) if end else None
return n # Attach the end node if it is provided
if end is not None:
node.add_child(end)
return node
class ID(compiler.ID): class ID(compiler.ID):
def cfa(self, pred, end): def cfa(self, pred, end = None):
n = CFG_Node(self) node = CFG_Node(self)
pred.add_child(n) pred.add_child(node)
n.add_child(end) if end else None
return n # Attach the end node if it is provided
if end is not None:
node.add_child(end)
return node
class AOP(compiler.AOP): class AOP(compiler.AOP):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Create nodes for each operand separately (like the example) # Create nodes for the used expressions and attach
left_node = self.arg1.cfa(pred, None) left_node = self.arg1.cfa(pred)
right_node = self.arg2.cfa(left_node, None) right_node = self.arg2.cfa(left_node)
# Create the comparison node with just the operator # Create the operator node and attach
op_node = CFG_Node(self) op_node = CFG_Node(self)
op_node.label = f"{self.operator}" op_node.label = f"{str(self.arg1)} {self.operator} {str(self.arg2)}"
right_node.add_child(op_node) right_node.add_child(op_node)
op_node.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
op_node.add_child(end)
return op_node return op_node
class COMP(compiler.COMP): class COMP(compiler.COMP):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Create nodes for each operand separately (like the example) # Create nodes for the used expressions and attach
left_node = self.arg1.cfa(pred, None) left_node = self.arg1.cfa(pred)
right_node = self.arg2.cfa(left_node, None) right_node = self.arg2.cfa(left_node)
# Create the comparison node with just the operator # Create the comparison node and attach
comp_node = CFG_Node(self) comp_node = CFG_Node(self)
comp_node.label = f"{self.operator}" comp_node.label = f"{str(self.arg1)} {self.operator} {str(self.arg2)}"
right_node.add_child(comp_node) right_node.add_child(comp_node)
comp_node.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
comp_node.add_child(end)
return comp_node return comp_node
class EQOP(compiler.EQOP): class EQOP(compiler.EQOP):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Create nodes for each operand separately (like the example) # Create nodes for the used expressions and attach
left_node = self.arg1.cfa(pred, None) left_node = self.arg1.cfa(pred)
right_node = self.arg2.cfa(left_node, None) right_node = self.arg2.cfa(left_node)
# Create the equation node with just the operator # Create the comparison node and attach
eqop_node = CFG_Node(self) eqop_node = CFG_Node(self)
eqop_node.label = f"{self.operator}" eqop_node.label = f"{str(self.arg1)} {self.operator} {str(self.arg2)}"
right_node.add_child(eqop_node) right_node.add_child(eqop_node)
eqop_node.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
eqop_node.add_child(end)
return eqop_node return eqop_node
class LOP(compiler.LOP): class LOP(compiler.LOP):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Create nodes for each operand separately # Create nodes for each operand separately
left_node = self.arg1.cfa(pred, None) left_node = self.arg1.cfa(pred)
right_node = self.arg2.cfa(left_node, None) right_node = self.arg2.cfa(left_node)
# Create the logical operation node with just the operator # Create the logical operation node with just the operator
lop_node = CFG_Node(self) lop_node = CFG_Node(self)
lop_node.label = f"{self.operator}" lop_node.label = f"{str(self.arg1)} {self.operator} {str(self.arg2)}"
right_node.add_child(lop_node) right_node.add_child(lop_node)
lop_node.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
lop_node.add_child(end)
return lop_node return lop_node
class ASSIGN(compiler.ASSIGN): class ASSIGN(compiler.ASSIGN):
def cfa(self, pred, end): def cfa(self, pred, end = None):
expr_node = self.expr.cfa(pred, None) # Unwraps expressions needed for assignment
expr_node = self.expr.cfa(pred)
# Assignment node
assign_node = CFG_Node(self) assign_node = CFG_Node(self)
expr_node.add_child(assign_node) expr_node.add_child(assign_node)
assign_node.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
assign_node.add_child(end)
return assign_node return assign_node
class SEQ(compiler.SEQ): class SEQ(compiler.SEQ):
def cfa(self, pred, end): def cfa(self, pred, end = None):
mid = self.exp1.cfa(pred, None) mid = self.exp1.cfa(pred)
if mid is None: if mid is None:
return None return None
return self.exp2.cfa(mid, end) return self.exp2.cfa(mid, end)
class IF(compiler.IF): class IF(compiler.IF):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Unwraps expressions needed for the condition
cond_node = self.cond.cfa(pred, None) cond_node = self.cond.cfa(pred, None)
# Attach junction node
diamond = CFG_DIAMOND(self.cond) diamond = CFG_DIAMOND(self.cond)
diamond.label = "<?>" # Use simple diamond label diamond.label = "<?>"
cond_node.add_child(diamond) cond_node.add_child(diamond)
# Define start and end entry and unwraps expressions
then_entry = CFG_Node() then_entry = CFG_Node()
else_entry = CFG_Node()
diamond.add_child(then_entry) diamond.add_child(then_entry)
else_entry = CFG_Node()
diamond.add_child(else_entry) diamond.add_child(else_entry)
# Attach the end node if it is provided
join = CFG_Node() join = CFG_Node()
join.add_child(end) if end else None if end is not None:
join.add_child(end)
# Connect the extracted expressions with the join
then_end = self.exp1.cfa(then_entry, join) then_end = self.exp1.cfa(then_entry, join)
else_end = self.exp2.cfa(else_entry, join) else_end = self.exp2.cfa(else_entry, join)
return join return join
class WHILE(compiler.WHILE): class WHILE(compiler.WHILE):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Handle different types of conditions
if hasattr(self.cond, 'arg1') and hasattr(self.cond, 'arg2'): if hasattr(self.cond, 'arg1') and hasattr(self.cond, 'arg2'):
# This is a comparison operation (e.g., a > b) # This is a comparison operation (e.g., a > b)
# Create the condition evaluation nodes # Create the condition evaluation nodes
left_node = self.cond.arg1.cfa(pred, None) left_node = self.cond.arg1.cfa(pred)
right_node = self.cond.arg2.cfa(left_node, None) right_node = self.cond.arg2.cfa(left_node)
# Create the comparison node and attach
comp_node = CFG_Node(self.cond) comp_node = CFG_Node(self.cond)
comp_node.label = f"({str(self.cond.arg1)} {self.cond.operator} {str(self.cond.arg2)})" comp_node.label = f"{str(self.cond.arg1)} {self.cond.operator} {str(self.cond.arg2)}"
right_node.add_child(comp_node) right_node.add_child(comp_node)
else: else:
# This is a simple condition (e.g., constant true/false or single expression) # This is a simple condition (e.g., constant true/false)
cond_node = self.cond.cfa(pred, None) cond_node = self.cond.cfa(pred)
comp_node = cond_node comp_node = cond_node
# Create the diamond node # Attach junction node
diamond = CFG_DIAMOND(self.cond) diamond = CFG_DIAMOND(self.cond)
diamond.label = "<>" # Use simple diamond label diamond.label = "<?>"
comp_node.add_child(diamond) comp_node.add_child(diamond)
# For the true branch, go to body # Unwrap the loop body
body_entry = CFG_Node() body_entry = CFG_Node()
diamond.add_child(body_entry) diamond.add_child(body_entry)
# The body should connect back to the start of condition evaluation # The body should connect back to the start of condition evaluation
body_end = self.body.cfa(body_entry, None) body_end = self.body.cfa(body_entry)
if body_end is not None: if body_end is not None:
# Connect body end back to the condition evaluation # Connect the body end back to the condition evaluation
if hasattr(self.cond, 'arg1') and hasattr(self.cond, 'arg2'): if hasattr(self.cond, 'arg1') and hasattr(self.cond, 'arg2'):
body_end.add_child(left_node) body_end.add_child(left_node)
else: else:
body_end.add_child(pred) # For simple conditions, go back to start body_end.add_child(pred)
# Attach joining node
after = CFG_Node() after = CFG_Node()
diamond.add_child(after) diamond.add_child(after)
after.add_child(end) if end else None
# Attach the end node if it is provided
if end is not None:
after.add_child(end)
return after return after
class CALL(compiler.CALL): class CALL(compiler.CALL):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# Create nodes for all argument values # Create nodes for all argument values
current_arg_node = pred current_arg_node = pred
for i, arg in enumerate(self.arg): for i, arg in enumerate(self.arg):
# Process argument through its cfa method to create proper CFG structure current_arg_node = arg.cfa(current_arg_node)
current_arg_node = arg.cfa(current_arg_node, None)
# Create and attach the call node
call_node = CFG_CALL(self) call_node = CFG_CALL(self)
call_node.label = f"CALL {self.f_name}" call_node.label = f"CALL {self.f_name}"
current_arg_node.add_child(call_node) current_arg_node.add_child(call_node)
# Create and attach the exit node
cont = CFG_Node() cont = CFG_Node()
cont.add_child(end) if end else None if end is not None:
cont.add_child(end)
# Find the functions in the function list
if self.f_name not in FUNCTIONS: if self.f_name not in FUNCTIONS:
raise RuntimeError(f"Call to undefined function '{self.f_name}'") raise RuntimeError(f"Call to undefined function '{self.f_name}'")
# Determine start and exit node of the function
f_start, f_end = FUNCTIONS[self.f_name] f_start, f_end = FUNCTIONS[self.f_name]
# Create return node from function # Create return node from function
@@ -175,10 +218,11 @@ class CALL(compiler.CALL):
f_end.add_child(return_node) f_end.add_child(return_node)
return_node.add_child(cont) return_node.add_child(cont)
# Span the start and exit nodes to the method body
call_node.add_child(f_start) call_node.add_child(f_start)
# Add direct edge from CALL to RET node (for the expected structure)
call_node.add_child(return_node) call_node.add_child(return_node)
# TODO: Why only g? Also f can be recursive.
# For recursive calls, we need to ensure proper return value flow # For recursive calls, we need to ensure proper return value flow
# In expressions like g(x)+x, the return value from g(x) flows to the continuation # In expressions like g(x)+x, the return value from g(x) flows to the continuation
# This is especially important for recursive functions where multiple calls return values # This is especially important for recursive functions where multiple calls return values
@@ -192,23 +236,27 @@ class CALL(compiler.CALL):
class DECL(compiler.DECL): class DECL(compiler.DECL):
def cfa(self, pred, end): def cfa(self, pred, end):
# Check if function is already registered (from first pass in LET) # Check if a function is already registered
if self.f_name in FUNCTIONS: if self.f_name in FUNCTIONS:
f_start, f_end = FUNCTIONS[self.f_name] f_start, f_end = FUNCTIONS[self.f_name]
else: else:
# Span the method body into a start and end node
f_start = CFG_START(self) f_start = CFG_START(self)
f_start.label = f"START {self.f_name}({', '.join(self.params)})" f_start.label = f"START {self.f_name}({', '.join(self.params)})"
f_end = CFG_END(self) f_end = CFG_END(self)
f_end.label = f"END {self.f_name}({', '.join(self.params)})" f_end.label = f"END {self.f_name}({', '.join(self.params)})"
FUNCTIONS[self.f_name] = (f_start, f_end) FUNCTIONS[self.f_name] = (f_start, f_end)
# Unwrap the method body
body_end = self.body.cfa(f_start, f_end) body_end = self.body.cfa(f_start, f_end)
# Attach the end node if it is provided
if body_end is not None: if body_end is not None:
body_end.add_child(f_end) body_end.add_child(f_end)
return pred return pred
class LET(compiler.LET): class LET(compiler.LET):
def cfa(self, pred, end): def cfa(self, pred, end = None):
# First pass: Register all function declarations # First pass: Register all function declarations
decls = self.decl if isinstance(self.decl, list) else [self.decl] decls = self.decl if isinstance(self.decl, list) else [self.decl]
for d in decls: for d in decls:
@@ -220,20 +268,19 @@ class LET(compiler.LET):
f_end.label = f"END {d.f_name}({', '.join(d.params)})" f_end.label = f"END {d.f_name}({', '.join(d.params)})"
FUNCTIONS[d.f_name] = (f_start, f_end) FUNCTIONS[d.f_name] = (f_start, f_end)
# Create global entry node # Create a global entry node for the function
global_entry = CFG_Node() global_entry = CFG_Node()
global_entry.label = "None" global_entry.label = "None"
pred.add_child(global_entry) pred.add_child(global_entry)
current = global_entry current = global_entry
# Second pass: Process declarations and build CFGs # Generate function declarations
for d in decls: for d in decls:
current = d.cfa(current, None) current = d.cfa(current, None)
if current is None: if current is None:
return None return None
# Process the body (function call) # Unwrap the body
body_result = self.body.cfa(current, end) body_result = self.body.cfa(current, end)
# Create global exit node # Create global exit node
@@ -241,9 +288,10 @@ class LET(compiler.LET):
global_exit.label = "None" global_exit.label = "None"
if body_result is not None: if body_result is not None:
body_result.add_child(global_exit) body_result.add_child(global_exit)
# Attach the end node if it is provided
if end is not None: if end is not None:
global_exit.add_child(end) global_exit.add_child(end)
return global_exit return global_exit
class RETURN(syntax.EXPRESSION): class RETURN(syntax.EXPRESSION):

View File

@@ -35,7 +35,6 @@ def make_cfg(ast):
return CFG(start, end) return CFG(start, end)
# Renders a diagram of the AST # Renders a diagram of the AST
def render_diagram(dot_string: str): def render_diagram(dot_string: str):
# Set DPI for PNG # Set DPI for PNG