Source code for pqg.query_builder

import random
import typing as t

from .entity import Entity, PropertyDate, PropertyEnum, PropertyFloat, PropertyInt, PropertyString
from .group_by_aggregation import GroupByAggregation
from .merge import Merge
from .operation import Operation
from .projection import Projection
from .query import Query
from .query_structure import QueryStructure
from .schema import Schema
from .selection import Selection


[docs] class QueryBuilder: """ A builder class for generating random, valid pandas DataFrame queries. This class constructs queries using a combination of operations like selection, projection, merge, and group by aggregation. It ensures that generated queries are valid by tracking available columns and maintaining referential integrity. Attributes: schema (Schema): The database schema containing entity definitions. query_structure (QueryStructure): Configuration parameters for query generation. multi_line (bool): Whether to format queries across multiple lines. operations (List[Operation]): List of operations to be applied in the query. entity_name (str): Name of the current entity being queried. entity (Entity): The current entity's schema definition. current_columns (Set[str]): Set of columns currently available for operations. required_columns (Set[str]): Columns that must be preserved (e.g., join keys). """ def __init__(self, schema: Schema, query_structure: QueryStructure, multi_line: bool): self.schema: Schema = schema self.query_structure: QueryStructure = query_structure self.multi_line: bool = multi_line self.entity: Entity = random.choice(list(self.schema.entities)) self.columns: t.Set[str] = set(self.entity.properties.keys()) self.required_columns: t.Set[str] = set() self.merge_entities: t.Set[str] = {self.entity.name} self.max_merges = self.query_structure.max_merges self.operations: t.List[Operation] = []
[docs] def build(self) -> Query: """ Build a complete query by randomly combining different operations. The method generates operations based on the following probabilities: - 50% chance of adding a selection (WHERE clause) - 50% chance of adding a projection (SELECT columns) - 0 to max_merges joins with other tables - 50% chance of adding a group by if grouping is enabled Each operation is validated to ensure it uses only available columns and maintains referential integrity. Returns: Query: A complete, valid query object with all generated operations. """ if ( self.columns and self.query_structure.max_selection_conditions > 0 and random.random() < self.query_structure.selection_probability ): self.operations.append(self._generate_selection()) if ( self.columns and self.query_structure.max_projection_columns > 0 and random.random() < self.query_structure.projection_probability ): self.operations.append(self._generate_projection()) num_merges = random.randint(0, self.max_merges) for _ in range(num_merges): try: self.operations.append(self._generate_merge(num_merges)) except ValueError: break if ( self.columns and self.query_structure.max_groupby_columns > 0 and random.random() < self.query_structure.groupby_aggregation_probability ): self.operations.append(self._generate_group_by_aggregation()) return Query(self.entity.name, self.operations, self.multi_line, self.columns)
def _generate_selection(self) -> Operation: """ Generate a WHERE clause for filtering rows. Creates conditions for filtering data based on column types: - Numeric columns: Comparison operators (==, !=, <, <=, >, >=) - String columns: Equality, inequality, and string operations - Enum columns: Value matching and set membership tests - Date columns: Date comparison operations Conditions are combined using AND (&) or OR (|) operators. The number of conditions is bounded by max_selection_conditions configuration. Returns: Operation: A Selection operation with the generated conditions. """ num_conditions = random.randint( 1, min(self.query_structure.max_selection_conditions, len(self.columns)) ) conditions, available_columns = [], list(self.columns) for i in range(num_conditions): column = random.choice(available_columns) property, next_op = ( self.entity.properties[column], random.choice(['&', '|']) if i < num_conditions - 1 else None, ) match property: case PropertyInt(minimum, maximum) | PropertyFloat(minimum, maximum): op = random.choice(['==', '!=', '<', '<=', '>', '>=']) value = random.uniform(minimum, maximum) if isinstance(property, PropertyInt): value = int(value) conditions.append((f"'{column}'", op, value, next_op)) case PropertyString(starting_character): op = random.choice(['==', '!=', '.str.startswith']) value = random.choice(starting_character) quoted_value = f"'{value}'" if "'" not in value else f'"{value}"' conditions.append((f"'{column}'", op, quoted_value, next_op)) case PropertyEnum(values): op = random.choice(['==', '!=', '.isin']) if op == '.isin': selected_values = random.sample(values, random.randint(1, len(values))) quoted_values = [f"'{v}'" if "'" not in v else f'"{v}"' for v in selected_values] value = f"[{', '.join(quoted_values)}]" else: value = random.choice(values) value = f"'{value}'" if "'" not in value else f'"{value}"' conditions.append((f"'{column}'", op, value, next_op)) case PropertyDate(minimum, maximum): op = random.choice(['==', '!=', '<', '<=', '>', '>=']) value = f"'{random.choice([minimum, maximum]).isoformat()}'" conditions.append((f"'{column}'", op, value, next_op)) return Selection(conditions) def _generate_projection(self) -> Operation: """ Generate a SELECT clause for choosing columns. Randomly selects a subset of available columns while ensuring that required columns (like join keys) are always included. The number of selected columns is bounded by max_projection_columns configuration. The operation updates the available columns for subsequent operations while maintaining required columns for joins and other operations. Returns: Operation: A Projection operation with the selected columns. """ available_for_projection = self.columns - self.required_columns if not available_for_projection: return Projection(list(self.columns)) to_project = random.randint( 1, min(self.query_structure.max_projection_columns, len(available_for_projection)) ) columns = set(random.sample(list(available_for_projection), to_project)) | self.required_columns self.columns = columns return Projection(list(columns)) def _generate_merge(self, num_merges: int) -> Operation: """ Generate a JOIN operation with another table. Creates a join operation based on foreign key relationships defined in the schema. Ensures join columns are preserved in projections on both sides of the join. The method: 1. Identifies possible join relationships 2. Randomly selects a valid join path 3. Creates a new query for the right side 4. Ensures join columns are preserved Returns: Operation: A Merge operation with the join conditions. Raises: ValueError: If no valid join relationships are available. """ possible_right_entities = [] for local_col, [foreign_col, foreign_table] in self.entity.foreign_keys.items(): if local_col in self.columns and foreign_table not in self.merge_entities: possible_right_entities.append((local_col, foreign_col, foreign_table)) if not possible_right_entities: raise ValueError('No valid entities for merge') left_on, right_on, right_entity_name = random.choice(possible_right_entities) right_query_structure = QueryStructure( groupby_aggregation_probability=0, max_groupby_columns=0, max_merges=self.max_merges - num_merges, max_projection_columns=self.query_structure.max_projection_columns, max_selection_conditions=self.query_structure.max_selection_conditions, projection_probability=self.query_structure.projection_probability, selection_probability=self.query_structure.selection_probability, ) right_builder = QueryBuilder(self.schema, right_query_structure, self.multi_line) right_builder.entity = self.schema[right_entity_name] right_builder.columns = set(right_builder.entity.properties.keys()) right_builder.required_columns.add(right_on) right_query = right_builder.build() self.columns = self.columns.union(right_query.columns) self.merge_entities = self.merge_entities.union(right_query.merge_entities) self.max_merges = self.max_merges - right_query.merge_count def format_join_columns(columns: str | t.List[str]) -> str: return ( f"[{', '.join(f"'{col}'" for col in columns)}]" if isinstance(columns, list) else f"'{columns}'" ) return Merge( right=right_query, left_on=format_join_columns(left_on), right_on=format_join_columns(right_on), ) def _generate_group_by_aggregation(self) -> Operation: """ Generate a GROUP BY clause with aggregation. Creates a grouping operation that: 1. Randomly selects columns to group by 2. Chooses an aggregation function (mean, sum, min, max, count) 3. Ensures numeric_only parameter is set appropriately The number of grouping columns is bounded by max_groupby_columns configuration and available columns. Returns: Operation: A GroupByAggregation operation with the grouping configuration. """ group_columns = random.sample( list(self.columns), random.randint(1, min(self.query_structure.max_groupby_columns, len(self.columns))), ) agg_function = random.choice(['mean', 'sum', 'min', 'max', 'count']) return GroupByAggregation(group_columns, agg_function)