# Source code for pyab_experiment.codegen.python.python_generator

"""Module that translates an Abstract Syntax Tree (AST) into executable Python code.

This module contains the PythonCodeGen class which traverses an ExperimentAST and
generates corresponding Python code for experiment variant selection. The generated
code includes conditional logic, group assignments, and deterministic choice functions.
"""

from pyab_experiment.language.grammar import (
    BooleanOperatorEnum,
    ConditionalType,
    ExperimentAST,
    ExperimentConditional,
    ExperimentGroup,
    Identifier,
    LogicalOperatorEnum,
    RecursivePredicate,
    TerminalPredicate,
)


[docs] class PythonCodeGen: """Generates Python code from an ExperimentAST representation. This class maintains state during AST traversal (like indentation level and variable tracking)and provides methods to generate formatted Python code for experiment variant selection. Args: experiment_ast (ExperimentAST): The AST to translate into Python code indentation_char (str, optional): Character used for indentation. Defaults to "\t" expose_experiment_variant_function (bool, optional): Whether to expose the variant selection function at the root level. Defaults to True Attributes: _local_vars (set): Tracks local variables used in the generated code _conditional_ids (set): Tracks conditional variables referenced in predicates _indent_depth (int): Current indentation level during code generation """ def __init__( self, experiment_ast: ExperimentAST, indentation_char: str = "\t", expose_experiment_variant_function: bool = True, ): self._experiment_ast = experiment_ast self._local_vars = set() self._conditional_ids = set() # to save conditional variables seen self._indentation_char = indentation_char self._newline = "\n" self._indent_depth = 0 self._expose_fn = expose_experiment_variant_function @property
[docs] def local_vars(self): return sorted(self._local_vars)
@property
[docs] def conditional_ids(self): return sorted(self._conditional_ids)
[docs] def render_topline(self) -> str: """ renders the topline of our python code i.e. imports, and top line comment """ return ( f"from functools import partial{self._newline}" "from pyab_experiment.codegen.python.custom_exceptions" f" import ExperimentConditionalFailedError{self._newline}" "from pyab_experiment.binning.binning import " f"deterministic_choice{self._newline}" f"{self._newline}#*******AUTOGENERATED DO NOT MODIFY " f"***********{self._newline}{self._newline}" )
[docs] def indent(self) -> str: """ helper function that renders scope sensitive indentation using an internal state to keep track the position in the AST traversal """ return "".join([self._indentation_char * self._indent_depth])
[docs] def generate(self) -> str: """ main method. Does a DFS on the AST rendering python code as it traverses the nodes """ composite_key = self.generate_key_definition() # either we define a function inside the main fn, or at the root if self._expose_fn: self._indent_depth = 1 else: self._indent_depth = 2 variant_fn_body = self._generate_conditionals(self._experiment_ast.conditions) variable_assignment = ", ".join([f"{id}={id}" for id in self.conditional_ids]) # add after variable_assignment variant_fn_body += f"{self._newline}{self._generate_exception()}{self._newline}" self._indent_depth -= 1 variant_fn_signature = ( f"{self.indent()}def choose_experiment_variant" f"({', '.join(self.conditional_ids)}): {self._newline}" ) self._indent_depth = 1 function_call = ( f"{self.indent()}return choose_experiment_variant" f"({variable_assignment})({composite_key}){self._newline}" ) fn_defn = ( f"def {self._experiment_ast.id}" f"({', '.join(self.local_vars+self.conditional_ids+['**kwargs'])}): " f"{self._newline}" ) if self._expose_fn: return ( f"{self.render_topline()}{fn_defn}" f"{function_call}{variant_fn_signature}{variant_fn_body}" ) else: return ( f"{self.render_topline()}{fn_defn}{variant_fn_signature}" f"{variant_fn_body}{function_call}" )
[docs] def generate_key_definition(self) -> str: """Generate the composite hash key used for deterministic experiment group assignment. This method combines the experiment's salt (if provided) with the string representation of splitting fields to create a composite key. This key is used by the deterministic_choice function to consistently assign users to experiment groups. The composite key is constructed as follows: 1. If a salt exists, it's used as a prefix 2. If splitting fields exist, their string representations are concatenated 3. If neither exists, returns "None" Example outputs: - With salt="exp1" and fields=[user_id, country]: "'exp1'+''.join(map(str, [user_id, country]))" - With only fields=[device_id]: "''''.join(map(str, [device_id]))" - With no salt or fields: "None" Returns: str: A Python expression that evaluates to the composite key used for group assignment. Side Effects: - Adds any splitting field variables to self._local_vars """ salt_def = ( f"'{self._experiment_ast.salt}'" if self._experiment_ast.salt is not None else "''" ) fields_def = "" if self._experiment_ast.splitting_fields: for var in self._experiment_ast.splitting_fields: self._local_vars.add(var) fields_def = f"''.join(map(str, [{', '.join(self.local_vars)}]))" if len(fields_def) == 0: composite_key = "None" else: composite_key = f"{salt_def}+{fields_def}" return composite_key
[docs] def _generate_exception(self) -> str: return f"{self.indent()}raise ExperimentConditionalFailedError()"
[docs] def _generate_conditionals( self, condition: ExperimentConditional | list[ExperimentGroup] ) -> str: """Generates Python code for conditional statements and group return functions. This method traverses the experiment's conditional structure and generates the corresponding Python code. It handles: 1. Conditional statements (if/elif/else) with their predicates and branches 2. Terminal group definitions that return partial functions for variant selection Args: condition: Either an ExperimentConditional for if/elif/else logic, or a list[ExperimentGroup] for terminal group definitions. - ExperimentConditional contains predicate and branch information - list[ExperimentGroup] contains group definitions and weights Returns: str: Generated Python code as a string, including: - For conditionals: if/elif/else statements with their predicates - For groups: partial function definitions for variant selection Raises: RuntimeError: If an unsupported condition type is provided Side Effects: - Adds conditional ids to self._conditional_ids Example generated code: # For conditionals: if (age >= 18): return partial(deterministic_choice, population=['A', 'B'], weights=[1, 2]) else: return partial(deterministic_choice, population=['C', 'D'], weights=[1, 1]) # For direct group definitions: return partial(deterministic_choice, population=['A', 'B'], weights=[1, 1]) """ match condition: case ExperimentConditional(): predicate = self._generate_predicate(condition.predicate) self._indent_depth += 1 true_branch_stmt = self._generate_conditionals(condition.true_branch) self._indent_depth -= 1 false_branch_stmt = ( self._generate_conditionals(condition.false_branch) if condition.false_branch is not None else "" ) match condition.conditional_type: case ConditionalType.IF: return ( f"{self.indent()}if {predicate}: " f"{self._newline}{true_branch_stmt}{false_branch_stmt}" ) case ConditionalType.ELIF: return ( f"{self.indent()}elif {predicate}: " 
f"{self._newline}{true_branch_stmt}{false_branch_stmt}" ) case ConditionalType.ELSE: return ( f"{self.indent()}else: " f"{self._newline}{true_branch_stmt}{false_branch_stmt}" ) case [*_]: statement = self._generate_group_return_statement(condition) return statement case _: raise RuntimeError( f"wrong type passed to conditional gen {type(condition)}" )
[docs] def _generate_predicate( self, predicate: TerminalPredicate | RecursivePredicate | None ) -> str: """Generates Python code for predicate expressions in conditional statements. This method handles three types of predicates: 1. Terminal predicates: Simple comparisons between two terms (e.g., "age >= 18") 2. Recursive predicates: Complex boolean expressions combining multiple predicates 3. None: Returns an empty string (used in ELSE conditions) Args: predicate: The predicate to convert to Python code. Can be: - TerminalPredicate: For simple comparisons - RecursivePredicate: For complex boolean expressions - None: For else conditions Returns: str: A Python expression string representing the predicate logic. Raises: RuntimeError: If an unsupported predicate type is provided. Side Effects: - Adds conditional ids to self._conditional_ids as it encounters them """ match predicate: case TerminalPredicate(): l_term = self._generate_term(predicate.left_term) r_term = self._generate_term(predicate.right_term) operator = self._generate_op(predicate.logical_operator) return f"({l_term} {operator} {r_term})" case RecursivePredicate(): l_pred = self._generate_predicate(predicate.left_predicate) r_pred = self._generate_predicate(predicate.right_predicate) operator = self._generate_op(predicate.boolean_operator) if ( predicate.boolean_operator == BooleanOperatorEnum.NOT ): # special case return f"({operator} {l_pred})" return f"({l_pred} {operator} {r_pred})" case None: return "" case _: raise RuntimeError( f"wrong type passed to predicate gen {type(predicate)}" )
[docs] def _generate_term(self, term: float | int | str | tuple | Identifier) -> str: """renders a term""" match term: case Identifier(name=identifier_name): self._conditional_ids.add(identifier_name) return identifier_name case str(): return f"'{term}'" case _: return term
[docs] def _generate_op(self, op: LogicalOperatorEnum | BooleanOperatorEnum) -> str: """renders legal python operations""" match op: case LogicalOperatorEnum.EQ: return "==" case LogicalOperatorEnum.NE: return "!=" case LogicalOperatorEnum.GT: return ">" case LogicalOperatorEnum.GE: return ">=" case LogicalOperatorEnum.LT: return "<" case LogicalOperatorEnum.LE: return "<=" case LogicalOperatorEnum.NOT_IN: return "not in" case LogicalOperatorEnum.IN: return "in" case BooleanOperatorEnum.NOT: return "not" case BooleanOperatorEnum.AND: return "and" case BooleanOperatorEnum.OR: return "or" case _: raise RuntimeError(f"OperatorEnum not matched: {op}")
[docs] def _generate_group_return_statement( self, group_statement: list[ExperimentGroup] ) -> str: """unwrap the experiment group, into a partial function call that applies the splitter logic""" population_list = str([group.group_definition for group in group_statement]) weight_list = str([group.group_weight for group in group_statement]) return ( f"{self.indent()}return partial(deterministic_choice, population=" f"{population_list}, weights={weight_list}){self._newline}" )