"""Module that translates an Abstract Syntax Tree (AST) into executable Python code.
This module contains the PythonCodeGen class which traverses an ExperimentAST and
generates corresponding Python code for experiment variant selection. The generated
code includes conditional logic, group assignments, and deterministic choice functions.
"""
from pyab_experiment.language.grammar import (
BooleanOperatorEnum,
ConditionalType,
ExperimentAST,
ExperimentConditional,
ExperimentGroup,
Identifier,
LogicalOperatorEnum,
RecursivePredicate,
TerminalPredicate,
)
[docs]
class PythonCodeGen:
"""Generates Python code from an ExperimentAST representation.
This class maintains state during AST traversal (like indentation level
and variable tracking)and provides methods to generate formatted Python
code for experiment variant selection.
Args:
experiment_ast (ExperimentAST): The AST to translate into Python code
indentation_char (str, optional): Character used for indentation.
Defaults to "\t" expose_experiment_variant_function (bool, optional):
Whether to expose the variant selection function at the root level.
Defaults to True
Attributes:
_local_vars (set): Tracks local variables used in the generated code
_conditional_ids (set): Tracks conditional variables referenced in predicates
_indent_depth (int): Current indentation level during code generation
"""
def __init__(
self,
experiment_ast: ExperimentAST,
indentation_char: str = "\t",
expose_experiment_variant_function: bool = True,
):
self._experiment_ast = experiment_ast
self._local_vars = set()
self._conditional_ids = set() # to save conditional variables seen
self._indentation_char = indentation_char
self._newline = "\n"
self._indent_depth = 0
self._expose_fn = expose_experiment_variant_function
@property
[docs]
def local_vars(self):
return sorted(self._local_vars)
@property
[docs]
def conditional_ids(self):
return sorted(self._conditional_ids)
[docs]
def render_topline(self) -> str:
"""
renders the topline of our python code
i.e. imports, and top line comment
"""
return (
f"from functools import partial{self._newline}"
"from pyab_experiment.codegen.python.custom_exceptions"
f" import ExperimentConditionalFailedError{self._newline}"
"from pyab_experiment.binning.binning import "
f"deterministic_choice{self._newline}"
f"{self._newline}#*******AUTOGENERATED DO NOT MODIFY "
f"***********{self._newline}{self._newline}"
)
[docs]
def indent(self) -> str:
"""
helper function that renders scope sensitive indentation
using an internal state to keep track the position in the
AST traversal
"""
return "".join([self._indentation_char * self._indent_depth])
[docs]
def generate(self) -> str:
"""
main method. Does a DFS on the AST rendering python
code as it traverses the nodes
"""
composite_key = self.generate_key_definition()
# either we define a function inside the main fn, or at the root
if self._expose_fn:
self._indent_depth = 1
else:
self._indent_depth = 2
variant_fn_body = self._generate_conditionals(self._experiment_ast.conditions)
variable_assignment = ", ".join([f"{id}={id}" for id in self.conditional_ids])
# add after variable_assignment
variant_fn_body += f"{self._newline}{self._generate_exception()}{self._newline}"
self._indent_depth -= 1
variant_fn_signature = (
f"{self.indent()}def choose_experiment_variant"
f"({', '.join(self.conditional_ids)}): {self._newline}"
)
self._indent_depth = 1
function_call = (
f"{self.indent()}return choose_experiment_variant"
f"({variable_assignment})({composite_key}){self._newline}"
)
fn_defn = (
f"def {self._experiment_ast.id}"
f"({', '.join(self.local_vars+self.conditional_ids+['**kwargs'])}): "
f"{self._newline}"
)
if self._expose_fn:
return (
f"{self.render_topline()}{fn_defn}"
f"{function_call}{variant_fn_signature}{variant_fn_body}"
)
else:
return (
f"{self.render_topline()}{fn_defn}{variant_fn_signature}"
f"{variant_fn_body}{function_call}"
)
[docs]
def generate_key_definition(self) -> str:
"""Generate the composite hash key used for deterministic experiment
group assignment.
This method combines the experiment's salt (if provided) with the string
representation of splitting fields to create a composite key.
This key is used by the deterministic_choice function to consistently
assign users to experiment groups.
The composite key is constructed as follows:
1. If a salt exists, it's used as a prefix
2. If splitting fields exist, their string representations are concatenated
3. If neither exists, returns "None"
Example outputs:
- With salt="exp1" and fields=[user_id, country]:
"'exp1'+''.join(map(str, [user_id, country]))"
- With only fields=[device_id]:
"''''.join(map(str, [device_id]))"
- With no salt or fields:
"None"
Returns:
str: A Python expression that evaluates to the composite key
used for group assignment.
Side Effects:
- Adds any splitting field variables to self._local_vars
"""
salt_def = (
f"'{self._experiment_ast.salt}'"
if self._experiment_ast.salt is not None
else "''"
)
fields_def = ""
if self._experiment_ast.splitting_fields:
for var in self._experiment_ast.splitting_fields:
self._local_vars.add(var)
fields_def = f"''.join(map(str, [{', '.join(self.local_vars)}]))"
if len(fields_def) == 0:
composite_key = "None"
else:
composite_key = f"{salt_def}+{fields_def}"
return composite_key
[docs]
def _generate_exception(self) -> str:
return f"{self.indent()}raise ExperimentConditionalFailedError()"
[docs]
def _generate_conditionals(
self, condition: ExperimentConditional | list[ExperimentGroup]
) -> str:
"""Generates Python code for conditional statements and group return functions.
This method traverses the experiment's conditional structure and generates the
corresponding Python code. It handles:
1. Conditional statements (if/elif/else) with their predicates and branches
2. Terminal group definitions that return partial functions for
variant selection
Args:
condition: Either an ExperimentConditional for if/elif/else logic,
or a list[ExperimentGroup] for terminal group definitions.
- ExperimentConditional contains predicate and branch information
- list[ExperimentGroup] contains group definitions and weights
Returns:
str: Generated Python code as a string, including:
- For conditionals: if/elif/else statements with their predicates
- For groups: partial function definitions for variant selection
Raises:
RuntimeError: If an unsupported condition type is provided
Side Effects:
- Adds conditional ids to self._conditional_ids
Example generated code:
# For conditionals:
if (age >= 18):
return partial(deterministic_choice, population=['A', 'B'],
weights=[1, 2])
else:
return partial(deterministic_choice, population=['C', 'D'],
weights=[1, 1])
# For direct group definitions:
return partial(deterministic_choice, population=['A', 'B'],
weights=[1, 1])
"""
match condition:
case ExperimentConditional():
predicate = self._generate_predicate(condition.predicate)
self._indent_depth += 1
true_branch_stmt = self._generate_conditionals(condition.true_branch)
self._indent_depth -= 1
false_branch_stmt = (
self._generate_conditionals(condition.false_branch)
if condition.false_branch is not None
else ""
)
match condition.conditional_type:
case ConditionalType.IF:
return (
f"{self.indent()}if {predicate}: "
f"{self._newline}{true_branch_stmt}{false_branch_stmt}"
)
case ConditionalType.ELIF:
return (
f"{self.indent()}elif {predicate}: "
f"{self._newline}{true_branch_stmt}{false_branch_stmt}"
)
case ConditionalType.ELSE:
return (
f"{self.indent()}else: "
f"{self._newline}{true_branch_stmt}{false_branch_stmt}"
)
case [*_]:
statement = self._generate_group_return_statement(condition)
return statement
case _:
raise RuntimeError(
f"wrong type passed to conditional gen {type(condition)}"
)
[docs]
def _generate_predicate(
self, predicate: TerminalPredicate | RecursivePredicate | None
) -> str:
"""Generates Python code for predicate expressions in conditional statements.
This method handles three types of predicates:
1. Terminal predicates: Simple comparisons between two terms (e.g., "age >= 18")
2. Recursive predicates: Complex boolean expressions combining
multiple predicates
3. None: Returns an empty string (used in ELSE conditions)
Args:
predicate: The predicate to convert to Python code. Can be:
- TerminalPredicate: For simple comparisons
- RecursivePredicate: For complex boolean expressions
- None: For else conditions
Returns:
str: A Python expression string representing the predicate logic.
Raises:
RuntimeError: If an unsupported predicate type is provided.
Side Effects:
- Adds conditional ids to self._conditional_ids as it encounters them
"""
match predicate:
case TerminalPredicate():
l_term = self._generate_term(predicate.left_term)
r_term = self._generate_term(predicate.right_term)
operator = self._generate_op(predicate.logical_operator)
return f"({l_term} {operator} {r_term})"
case RecursivePredicate():
l_pred = self._generate_predicate(predicate.left_predicate)
r_pred = self._generate_predicate(predicate.right_predicate)
operator = self._generate_op(predicate.boolean_operator)
if (
predicate.boolean_operator == BooleanOperatorEnum.NOT
): # special case
return f"({operator} {l_pred})"
return f"({l_pred} {operator} {r_pred})"
case None:
return ""
case _:
raise RuntimeError(
f"wrong type passed to predicate gen {type(predicate)}"
)
[docs]
def _generate_term(self, term: float | int | str | tuple | Identifier) -> str:
"""renders a term"""
match term:
case Identifier(name=identifier_name):
self._conditional_ids.add(identifier_name)
return identifier_name
case str():
return f"'{term}'"
case _:
return term
[docs]
def _generate_op(self, op: LogicalOperatorEnum | BooleanOperatorEnum) -> str:
"""renders legal python operations"""
match op:
case LogicalOperatorEnum.EQ:
return "=="
case LogicalOperatorEnum.NE:
return "!="
case LogicalOperatorEnum.GT:
return ">"
case LogicalOperatorEnum.GE:
return ">="
case LogicalOperatorEnum.LT:
return "<"
case LogicalOperatorEnum.LE:
return "<="
case LogicalOperatorEnum.NOT_IN:
return "not in"
case LogicalOperatorEnum.IN:
return "in"
case BooleanOperatorEnum.NOT:
return "not"
case BooleanOperatorEnum.AND:
return "and"
case BooleanOperatorEnum.OR:
return "or"
case _:
raise RuntimeError(f"OperatorEnum not matched: {op}")
[docs]
def _generate_group_return_statement(
self, group_statement: list[ExperimentGroup]
) -> str:
"""unwrap the experiment group, into a partial function call
that applies the splitter logic"""
population_list = str([group.group_definition for group in group_statement])
weight_list = str([group.group_weight for group in group_statement])
return (
f"{self.indent()}return partial(deterministic_choice, population="
f"{population_list}, weights={weight_list}){self._newline}"
)