100 lines
4.2 KiB
Python
100 lines
4.2 KiB
Python
from typing import Dict, Optional, List, Any, Tuple, override
|
|
from MaMa import MaMa
|
|
|
|
class MaMaMa(MaMa):
    """MaMa machine extended with parameterised macros.

    Macros are named sub-programs registered with :meth:`add_macro`. Before
    execution, :meth:`run` inlines every macro call in ``self.prog``
    (recursively), substituting macro parameters into the arguments of the
    inlined micro-instructions, and records for every resulting instruction
    which macro calls it was expanded from (see :meth:`structure`).
    """

    def __init__(self, prog, stack=None) -> None:
        # Registered macros: name -> {"prog": {index: micro}, "args": params | None}.
        self.macros: Dict[str, Dict[str, Any]] = {}
        # Snapshots taken at registration time. The special key "__top__"
        # holds the un-flattened top-level program ({index: micro}).
        self.initial_macros: Dict[str, Dict[str, Any]] = {}
        # Per flattened instruction index: names of the enclosing macros,
        # outermost first.
        self._macro_trace: Dict[int, List[str]] = {}
        # Per flattened instruction index: the parameter bindings of each
        # enclosing macro call, parallel to ``_macro_trace``.
        self._macro_params: Dict[int, List[Dict[str, Any]]] = {}
        # Copy of ``self.prog`` as produced by the last flatten (None = never
        # flattened / invalidated); lets ``run`` skip redundant re-flattening.
        self._flat_prog: Optional[Dict[int, str]] = None

        # Attributes above are set before the base constructor runs, keeping
        # the original ordering (presumably MaMa.__init__ may touch them via
        # overridden methods — unverified).
        super().__init__(prog, stack)

        self.initial_macros["__top__"] = dict(self.prog)

    def add_macro(
        self,
        name: str,
        prog: List[str] | Dict[int, str],
        args: Optional[List[str]] = None,
    ) -> None:
        """Register (or replace) the macro ``name``.

        Args:
            name: Identifier used to invoke the macro from a program.
            prog: Macro body, either a list of micro-instruction strings
                (indexed in order) or an ``{index: micro}`` mapping.
            args: Formal parameter names, bound positionally to the call-site
                arguments during expansion; ``None`` for no parameters.
        """
        if isinstance(prog, list):
            prog = dict(enumerate(prog))

        self.macros[name] = {"prog": prog, "args": args}
        # Independent snapshot, so later edits to the live macro definition
        # do not leak into ``initial_macros``.
        self.initial_macros[name] = {
            "prog": dict(prog),
            "args": list(args) if args is not None else None,
        }
        # A newly added macro may need expanding on the next run.
        self._flat_prog = None

    def _eval_expr(self, expr: Any, env: Dict[str, Any]) -> Any:
        """Substitute macro parameters into ``expr`` and try to evaluate it.

        Non-string values are returned unchanged. For strings, each occurrence
        of an ``env`` key that stands as a whole token (not part of a longer
        identifier) is replaced by ``str(value)`` in a single pass. Overlapping
        names such as ``n``/``n2`` therefore do not clobber each other, and
        substituted values are never substituted again. The result is passed
        to ``eval``. If evaluation fails (e.g. the text is a register or label
        name), the substituted text is returned.

        SECURITY NOTE(review): ``eval`` with empty ``__builtins__`` is not a
        sandbox. Only use this with trusted macro programs.
        """
        import re

        if not isinstance(expr, str):
            return expr

        lookup = {str(k): v for k, v in env.items() if str(k)}
        safe_expr = expr
        if lookup:
            # Longest names first so the alternation prefers e.g. "n2" over "n".
            names = sorted(lookup, key=len, reverse=True)
            pattern = re.compile(
                r"(?<!\w)(?:" + "|".join(map(re.escape, names)) + r")(?!\w)"
            )
            safe_expr = pattern.sub(lambda m: str(lookup[m.group(0)]), expr)

        try:
            return eval(safe_expr, {"__builtins__": {}})
        except Exception:
            # Deliberately best-effort: anything non-evaluable stays as text.
            return safe_expr

    @override
    def run(self, max_steps: int = 1000):
        """Flatten macro calls in ``self.prog`` if needed, then execute it.

        Flattening is skipped when ``self.prog`` is still exactly the program
        produced by the previous flatten and no macro was added since. Without
        this, re-flattening an already-flat program finds no macro calls and
        would erase the macro trace.

        Args:
            max_steps: Forwarded unchanged to ``MaMa.run``.

        Returns:
            Whatever ``MaMa.run`` returns.
        """
        if self._flat_prog is None or self.prog != self._flat_prog:
            self.__flatten_macro()
        return super().run(max_steps)

    def __flatten_macro(self) -> None:
        """Inline every macro call in ``self.prog`` and rebuild the trace.

        Replaces ``self.prog`` with a densely re-indexed program of plain
        micro-instructions. For each index it records ``_macro_trace`` (names
        of the enclosing macros, outermost first) and ``_macro_params`` (the
        parameters bound at each of those call sites).

        Raises:
            RecursionError: if a macro calls itself directly or indirectly.
                Expansion has no conditionals, so such a cycle never ends.
        """
        # A frame is (macro name, parameters bound at that call site).
        Frame = Tuple[str, Dict[str, Any]]

        def expand(
            prog: Dict[int, str], frames: List[Frame], env: Dict[str, Any]
        ) -> List[Tuple[str, List[Frame]]]:
            out: List[Tuple[str, List[Frame]]] = []
            for _, micro in sorted(prog.items()):
                name, args = self.decode(micro)

                # Substitute parameters from env and evaluate expressions.
                if args:
                    args = [self._eval_expr(env.get(str(a), a), env) for a in args]
                    micro = f"{name}({','.join(map(str, args))})"

                if name not in self.macros:
                    out.append((micro, list(frames)))
                    continue

                if any(frame_name == name for frame_name, _ in frames):
                    chain = " -> ".join([n for n, _ in frames] + [name])
                    raise RecursionError(f"recursive macro expansion: {chain}")

                params = self.macros[name].get("args") or []
                bound = dict(zip(params, args)) if args and params else {}
                # Inner macros see the caller's bindings plus their own
                # (their own win on name clashes).
                out.extend(
                    expand(
                        self.macros[name]["prog"],
                        frames + [(name, bound)],
                        {**env, **bound},
                    )
                )
            return out

        expanded = expand(self.prog, [], {})

        self.prog = {i: call for i, (call, _) in enumerate(expanded)}
        self._macro_trace = {
            i: [frame_name for frame_name, _ in frames]
            for i, (_, frames) in enumerate(expanded)
        }
        self._macro_params = {
            i: [dict(bound) for _, bound in frames]
            for i, (_, frames) in enumerate(expanded)
        }
        self._flat_prog = dict(self.prog)

    @override
    def structure(self) -> Dict[int, Dict[str, Any]]:
        """Describe each instruction of ``self.prog`` with its macro origin.

        Returns:
            ``{index: {"micro": micro, "macros": [{"name": ..., "params": {...}}, ...]}}``.
            ``macros`` lists the enclosing macros, outermost first. ``params``
            maps each declared parameter of that macro to the value bound at
            the specific call site this instruction came from, or ``None`` if
            that call supplied no value for it. ``params`` is ``{}`` for a name
            that is not (or no longer) a registered macro. Until the program
            has been flattened, every ``macros`` list is empty.
        """
        struct: Dict[int, Dict[str, Any]] = {}

        for i, micro in sorted(self.prog.items()):
            bindings = self._macro_params.get(i, [])
            macros_info: List[Dict[str, Any]] = []

            for depth, macro_name in enumerate(self._macro_trace.get(i, [])):
                bound = bindings[depth] if depth < len(bindings) else {}
                params: Dict[str, Any] = {}
                if macro_name in self.macros:
                    param_names = self.macros[macro_name].get("args") or []
                    params = {p: bound.get(p) for p in param_names}
                macros_info.append({"name": macro_name, "params": params})

            struct[i] = {"micro": micro, "macros": macros_info}

        return struct
|