import re
from typing import Any, Dict, List, Optional, Tuple, override

from MaMa import MaMa


class MaMaMa(MaMa):
    """MaMa machine extended with (possibly nested, parameterised) macros.

    Macros are registered with add_macro() and expanded into plain micro
    instructions by run() before execution is delegated to MaMa.run().
    After expansion, structure() reports for each flat instruction the chain
    of macros it came from and the argument values of each invocation.
    """

    def __init__(self, prog, stack=None) -> None:
        # Registered macros: name -> {"prog": {index: micro}, "args": params or None}.
        self.macros: Dict[str, Dict[str, Any]] = {}
        # Snapshot of each macro as registered; "__top__" holds the original
        # top-level program itself (an {index: micro} mapping).
        self.initial_macros: Dict[str, Dict[str, Any]] = {}
        # Flat instruction index -> names of enclosing macros, outermost first.
        self._macro_trace: Dict[int, List[str]] = {}
        # Flat instruction index -> parameter bindings of each enclosing macro,
        # parallel to the corresponding _macro_trace list.
        self._macro_params: Dict[int, List[Dict[str, Any]]] = {}
        # Copy of the program produced by the last expansion (None before any).
        self._flat_prog: Optional[Dict[int, str]] = None
        super().__init__(prog, stack)
        # NOTE(review): assumes MaMa.__init__ stores the program in self.prog
        # as something dict() accepts, e.g. an {index: micro} mapping.
        self.initial_macros["__top__"] = dict(self.prog)

    def add_macro(
        self,
        name: str,
        prog: List[str] | Dict[int, str],
        args: Optional[List[str]] = None,
    ) -> None:
        """Register macro ``name`` whose body is a sequence of micro instructions.

        Args:
            name: Instruction name under which calls to the macro are recognised.
            prog: Body as a list/tuple (indexed in order) or an {index: micro}
                mapping; the body is expanded in ascending index order.
            args: Formal parameter names, bound positionally to call arguments.
        """
        # Copy the inputs so later mutation by the caller can change neither
        # the registered macro nor its snapshot in initial_macros.
        if isinstance(prog, (list, tuple)):
            body = dict(enumerate(prog))
        else:
            body = dict(prog)
        params = list(args) if args is not None else None
        self.macros[name] = {"prog": body, "args": params}
        self.initial_macros[name] = {
            "prog": dict(body),
            "args": None if params is None else list(params),
        }

    def _eval_expr(self, expr: Any, env: Dict[str, Any]) -> Any:
        """Evaluate an argument expression after substituting bound parameters.

        ``int``/``float`` values (and any non-string) are returned unchanged.
        For a string, every occurrence of an ``env`` key that stands as a whole
        token is replaced by its parenthesised value. A key inside a longer
        identifier is not a whole token: ``n`` is untouched in ``n_loop`` or
        ``min``. The result is then evaluated as a Python expression with
        empty builtins.

        If evaluation fails for any reason, the original, unsubstituted
        ``expr`` is returned. Non-arithmetic arguments such as labels
        therefore pass through untouched.

        SECURITY: ``eval`` with empty ``__builtins__`` is NOT a sandbox; only
        use this on trusted macro programs.
        """
        if not isinstance(expr, str):
            return expr
        substituted = expr
        lookup = {str(k): v for k, v in env.items() if str(k)}
        if lookup:
            # Single regex pass. Longest names come first, so overlapping
            # names resolve to the longest match. Substituted text is never
            # rescanned, so values cannot chain into further replacements.
            names = sorted(lookup, key=len, reverse=True)
            pattern = r"(?<!\w)(?:" + "|".join(map(re.escape, names)) + r")(?!\w)"
            # Parenthesise so a negative value keeps its sign under ** etc.
            substituted = re.sub(
                pattern, lambda m: f"({lookup[m.group(0)]})", expr
            )
        try:
            return eval(substituted, {"__builtins__": {}})
        except Exception:
            # Deliberate best effort: anything that is not a valid expression
            # (labels, unknown names, division by zero, ...) is kept verbatim.
            return expr

    def _needs_flatten(self) -> bool:
        """Return whether run() must (re)expand macros in self.prog."""
        if self._flat_prog is None or self.prog != self._flat_prog:
            return True
        # A macro registered after the last expansion may now match an
        # instruction that was previously left as a plain micro.
        return any(
            self.decode(micro)[0] in self.macros for micro in self.prog.values()
        )

    @override
    def run(self, max_steps: int = 1000):
        """Expand macros if needed, then execute via MaMa.run(max_steps)."""
        # Re-expanding an already flat program would reset every macro trace
        # to [] and erase what structure() reports, so only expand when needed.
        if self._needs_flatten():
            self.__flatten_macro()
        return super().run(max_steps)

    def __flatten_macro(self) -> None:
        """Expand every macro call in self.prog into plain micro instructions.

        Replaces self.prog with the flat program, re-indexed from 0. For each
        flat instruction it records two things:
        - the chain of enclosing macro names, in ``_macro_trace``;
        - the parameter bindings of each of those macros, in ``_macro_params``.

        Raises:
            ValueError: if a macro calls itself directly or indirectly, which
                would otherwise expand forever.
        """
        Frame = Tuple[str, Dict[str, Any]]  # (macro name, its bindings)

        def expand(
            prog: Dict[int, str], frames: List[Frame], env: Dict[str, Any]
        ) -> List[Tuple[str, List[Frame]]]:
            out: List[Tuple[str, List[Frame]]] = []
            for _, micro in sorted(prog.items()):
                name, args = self.decode(micro)
                if args:
                    # Substitute bound parameters and fold arithmetic arguments.
                    args = [self._eval_expr(env.get(str(a), a), env) for a in args]
                    micro = f"{name}({','.join(map(str, args))})"
                if name not in self.macros:
                    out.append((micro, list(frames)))
                    continue
                chain = [macro_name for macro_name, _ in frames]
                if name in chain:
                    raise ValueError(
                        "recursive macro expansion: " + " -> ".join(chain + [name])
                    )
                macro = self.macros[name]
                params = macro.get("args") or []
                bindings = dict(zip(params, args)) if args and params else {}
                # Outer bindings remain visible inside nested macros; the
                # callee's own parameters shadow them.
                new_env = {**env, **bindings}
                out.extend(expand(macro["prog"], frames + [(name, bindings)], new_env))
            return out

        expanded = expand(self.prog, [], {})
        self.prog = {i: micro for i, (micro, _) in enumerate(expanded)}
        self._macro_trace = {
            i: [name for name, _ in frames] for i, (_, frames) in enumerate(expanded)
        }
        self._macro_params = {
            i: [bound for _, bound in frames] for i, (_, frames) in enumerate(expanded)
        }
        self._flat_prog = dict(self.prog)

    @override
    def structure(self) -> Dict[int, Dict[str, Any]]:
        """Describe each instruction of self.prog with its macro provenance.

        Returns:
            ``{index: {"micro": micro, "macros": [{"name": n, "params": {...}}, ...]}}``

            - "macros" lists the enclosing macros, outermost first.
            - Each "params" maps every declared parameter of that macro to the
              value bound for that particular invocation, or None if the call
              supplied no value.
            - Before run() has expanded the program, every "macros" list is empty.
        """
        struct: Dict[int, Dict[str, Any]] = {}
        for i, micro in sorted(self.prog.items()):
            frame_bindings = self._macro_params.get(i, [])
            macros: List[Dict[str, Any]] = []
            for depth, macro_name in enumerate(self._macro_trace.get(i, [])):
                macro_info: Dict[str, Any] = {"name": macro_name, "params": {}}
                if macro_name in self.macros:
                    bound = frame_bindings[depth] if depth < len(frame_bindings) else {}
                    param_names = self.macros[macro_name].get("args") or []
                    macro_info["params"] = {p: bound.get(p) for p in param_names}
                macros.append(macro_info)
            struct[i] = {"micro": micro, "macros": macros}
        return struct