import multiprocessing as mp
import os
import statistics as stats
import typing as t
from collections import Counter
from dataclasses import dataclass, field
from functools import partial
import pandas as pd
from tqdm import tqdm
from .arguments import QueryFilter
from .group_by_aggregation import GroupByAggregation
from .merge import Merge
from .projection import Projection
from .query import Query
from .query_structure import QueryStructure
from .selection import Selection
QueryResult = t.Tuple[t.Optional[t.Union[pd.DataFrame, pd.Series]], t.Optional[str]]
[docs]
@dataclass
class ExecutionStatistics:
"""
Statistics about query execution results.
This class tracks success/failure rates and categorizes execution outcomes
to provide insights into query generation quality.
Attributes:
successful (int): Number of queries that executed without error
failed (int): Number of queries that raised an error
non_empty (int): Number of queries returning data (rows/values)
empty (int): Number of queries returning empty results
errors (Counter[str]): Count of each error type encountered
Example:
stats = ExecutionStatistics()
stats.successful = 95
stats.failed = 5
stats.errors["KeyError"] = 3
print(stats) # Shows execution summary with percentages
"""
successful: int = 0
failed: int = 0
non_empty: int = 0
empty: int = 0
errors: Counter[str] = field(default_factory=Counter)
def __str__(self) -> str:
total = self.successful + self.failed
if total == 0:
return ''
success_rate, empty_rate = (
(self.successful / total) * 100,
(self.empty / total) * 100 if total > 0 else 0,
)
lines = [
f'Execution Results (n = {total}):',
f' Success Rate: {success_rate:5.1f}% ({self.successful} / {total})',
f' Queries with Empty Result Sets: {empty_rate:4.1f}% ({self.empty} / {total})',
]
if self.errors:
lines.extend(
['', 'Errors:', *[f' {error}: {count}' for error, count in self.errors.most_common()]]
)
return '\n'.join(lines)
[docs]
@dataclass
class QueryStatistics:
"""
Statistics comparing actual query characteristics against target parameters.
Analyzes how closely generated queries match the desired structure parameters
and collects distributions of various operation characteristics.
Attributes:
query_structure: Target parameters for query generation
total_queries: Total number of queries analyzed
queries_with_selection: Number of queries containing selection operations
queries_with_projection: Number of queries containing projection operations
queries_with_groupby: Number of queries containing groupby operations
queries_with_merge: Number of queries containing merge operations
selection_conditions: List of condition counts from each selection operation
projection_columns: List of column counts from each projection operation
groupby_columns: List of column counts from each groupby operation
merge_count: List of merge counts from each query
execution_results: Statistics about query execution outcomes
"""
query_structure: QueryStructure
total_queries: int = 0
queries_with_selection: int = 0
queries_with_projection: int = 0
queries_with_groupby: int = 0
queries_with_merge: int = 0
selection_conditions: t.List[int] = field(default_factory=list)
projection_columns: t.List[int] = field(default_factory=list)
groupby_columns: t.List[int] = field(default_factory=list)
merge_count: t.List[int] = field(default_factory=list)
execution_results: ExecutionStatistics = field(default_factory=ExecutionStatistics)
@staticmethod
def _safe_stats(values: t.List[int]) -> tuple[float, float, int]:
"""Calculate mean, standard deviation, and max for a list of values."""
if not values:
return 0.0, 0.0, 0
return (stats.mean(values), stats.stdev(values) if len(values) > 1 else 0.0, max(values))
@staticmethod
def _format_frequency(label: str, actual: float, target: t.Optional[float] = None) -> str:
"""Format a frequency line."""
if target is None:
return f' {label} {actual:.1f}%'
return f' {label} {actual:.1f}% vs {target*100:.1f}%'
@staticmethod
def _format_count(label: str, mean: float, std: float, max_val: int, limit: int) -> str:
"""Format a count line."""
return f' {label} {mean:.1f} ± {std:.1f} | {max_val} vs {limit}'
def __str__(self) -> str:
if self.total_queries == 0:
return ''
probabilities = (
self.queries_with_selection / self.total_queries * 100,
self.queries_with_projection / self.total_queries * 100,
self.queries_with_merge / self.total_queries * 100,
self.queries_with_groupby / self.total_queries * 100,
)
lines = [
f'Query Statistics (n = {self.total_queries})',
'',
'Operation Probabilities (actual vs target):',
self._format_frequency(
'Selection:', probabilities[0], self.query_structure.selection_probability
),
self._format_frequency(
'Projection:', probabilities[1], self.query_structure.projection_probability
),
self._format_frequency(
'GroupBy:', probabilities[3], self.query_structure.groupby_aggregation_probability
),
self._format_frequency('Merge:', probabilities[2]),
'',
'Operation Counts (avg ± std | max vs limit):',
]
for data, label, limit in [
(
self.selection_conditions,
'Selection conditions:',
self.query_structure.max_selection_conditions,
),
(self.projection_columns, 'Projection columns:', self.query_structure.max_projection_columns),
(self.groupby_columns, 'GroupBy columns:', self.query_structure.max_groupby_columns),
(self.merge_count, 'Merges per query:', self.query_structure.max_merges),
]:
if data:
mean, std, max_val = self._safe_stats(data)
lines.append(self._format_count(label, mean, std, max_val, limit))
lines.extend(['', str(self.execution_results)])
return '\n'.join(lines)
[docs]
class QueryPool:
"""
Manages a collection of database queries with execution and analysis capabilities.
Provides functionality for executing queries in parallel, filtering based on results,
computing statistics, and managing query persistence. The pool maintains execution
results to avoid redundant computation when performing multiple operations.
Attributes:
_queries: List of Query objects in the pool
_query_structure: Parameters controlling query generation
_sample_data: Dictionary mapping entity names to sample DataFrames
_results: Cached query execution results (DataFrame/Series, error message)
_with_status: Whether to display progress bars during operations
"""
def __init__(
self,
queries: t.List[Query],
query_structure: QueryStructure,
sample_data: t.Dict[str, pd.DataFrame],
multi_processing: bool = True,
with_status: bool = False,
):
self._queries = queries
self._query_structure = query_structure
self._sample_data = sample_data
self._results: t.List[QueryResult] = []
self._multi_processing = multi_processing
self._with_status = with_status
[docs]
def __len__(self) -> int:
"""Return the number of queries in the pool."""
return len(self._queries)
[docs]
def __iter__(self) -> t.Iterator[Query]:
"""Iterate over the queries in the pool."""
return iter(self._queries)
@staticmethod
def _execute_multi_line_query(
query: Query, sample_data: t.Dict[str, pd.DataFrame]
) -> QueryResult:
"""Execute a multi-line query by executing each line sequentially."""
try:
local_vars = sample_data.copy()
for line in str(query).split('\n'):
df_name, expression = line.split(' = ', 1)
result = eval(expression, {}, local_vars)
local_vars[df_name] = result
last_df = max(k for k in local_vars.keys() if k.startswith('df'))
return local_vars[last_df], None
except Exception as e:
return None, f'{type(e).__name__}: {str(e)}'
@staticmethod
def _execute_single_query(query: Query, sample_data: t.Dict[str, pd.DataFrame]) -> QueryResult:
"""Execute a single query and handle any errors."""
try:
if query.multi_line:
return QueryPool._execute_multi_line_query(query, sample_data)
result = pd.eval(str(query), local_dict=sample_data)
if isinstance(result, (pd.DataFrame, pd.Series)):
return result, None
return None, f'Result was not a DataFrame or Series: {type(result)}'
except Exception as e:
return None, f'{type(e).__name__}: {str(e)}'
[docs]
def execute(
self,
force_execute: bool = False,
num_processes: t.Optional[int] = None,
) -> t.List[QueryResult]:
"""
Execute all queries against the sample data, either in parallel or sequentially.
Evaluates queries using either multiprocessing for parallel execution or
sequential processing. Results are cached to avoid re-execution unless
explicitly requested.
Args:
force_execute: If True, re-execute all queries even if results exist
num_processes:
Number of parallel processes to use. Defaults to CPU count.
Only used when multi_processing is True.
Returns:
List of tuples containing (result, error) for each query
"""
if len(self._queries) == 0:
return []
if len(self._results) > 0 and not force_execute:
return self._results
f = partial(self._execute_single_query, sample_data=self._sample_data)
if self._multi_processing:
ctx = mp.get_context('fork')
with ctx.Pool(num_processes) as pool:
iterator = pool.imap(f, self._queries)
if self._with_status:
iterator = tqdm(
iterator, total=len(self._queries), desc='Executing queries', unit='query'
)
self._results = list(iterator)
else:
if self._with_status:
iterator = tqdm(self._queries, desc='Executing queries', unit='query')
else:
iterator = self._queries
self._results = [f(query) for query in iterator]
return self._results
[docs]
def filter(
self,
filter_type: QueryFilter,
force_execute: bool = False,
) -> None:
"""
Filter queries based on their execution results.
Modifies the query pool in-place to keep only queries matching the filter
criteria. Executes queries if results don't exist.
Args:
filter_type: Criteria for keeping queries (NON_EMPTY, EMPTY, etc.)
force_execute: If True, re-execute queries before filtering
"""
if not self._results or force_execute:
self._results = self.execute()
filtered_queries, filtered_results = [], []
for query, result_tuple in zip(self._queries, self._results):
result, error = result_tuple
should_keep = False
match filter_type:
case QueryFilter.NON_EMPTY:
should_keep = result is not None and (
(isinstance(result, pd.DataFrame) and not result.empty)
or (isinstance(result, pd.Series) and result.size > 0)
)
case QueryFilter.EMPTY:
should_keep = result is not None and (
(isinstance(result, pd.DataFrame) and result.empty)
or (isinstance(result, pd.Series) and result.size == 0)
)
case QueryFilter.HAS_ERROR:
should_keep = error is not None
case QueryFilter.WITHOUT_ERROR:
should_keep = error is None
if should_keep:
filtered_queries.append(query)
filtered_results.append(result_tuple)
self._queries, self._results = filtered_queries, filtered_results
[docs]
def items(self) -> t.Iterator[tuple[Query, QueryResult]]:
"""
Iterate over query-result pairs.
Each iteration yields a tuple containing a query and its execution result.
If results haven't been computed yet, executes the queries first.
Yields:
tuple[Query, QueryResult]: Pairs of (query, (result, error))
"""
if not self._results:
self.execute()
return zip(self._queries, self._results)
[docs]
def results(self) -> t.Iterator[QueryResult]:
"""
Iterate over query results.
If results haven't been computed yet, executes the queries first.
Yields:
QueryResult: Pairs of (result, error) for each query
"""
if not self._results:
self.execute()
return iter(self._results)
[docs]
def statistics(self, force_execute: bool = False) -> QueryStatistics:
"""
Generate comprehensive statistics about the query pool.
Analyzes query characteristics and execution results to measure how well
the generated queries match the target parameters.
Args:
force_execute: If True, re-execute queries before analysis
Returns:
Statistics comparing actual vs. target characteristics
"""
if not self._results or force_execute:
self._results = self.execute()
statistics = QueryStatistics(query_structure=self._query_structure)
statistics.total_queries = len(self._queries)
for query in self._queries:
has_selection = has_projection = has_groupby = False
selection_count = projection_count = groupby_count = 0
for op in query.operations:
match op:
case Selection(conditions):
has_selection = True
selection_count = len(conditions)
case Projection(columns):
has_projection = True
projection_count = len(columns)
case GroupByAggregation(columns, _):
has_groupby = True
groupby_count = len(columns)
case Merge():
...
if has_selection:
statistics.queries_with_selection += 1
statistics.selection_conditions.append(selection_count)
if has_projection:
statistics.queries_with_projection += 1
statistics.projection_columns.append(projection_count)
if has_groupby:
statistics.queries_with_groupby += 1
statistics.groupby_columns.append(groupby_count)
merge_count = query.merge_count
if merge_count > 0:
statistics.queries_with_merge += 1
statistics.merge_count.append(merge_count)
if self._results:
for result, error in self._results:
if error is not None:
statistics.execution_results.failed += 1
statistics.execution_results.errors[error] += 1
else:
statistics.execution_results.successful += 1
if result is not None:
if (isinstance(result, pd.DataFrame) and not result.empty) or (
isinstance(result, pd.Series) and result.size > 0
):
statistics.execution_results.non_empty += 1
else:
statistics.execution_results.empty += 1
return statistics
[docs]
def save(self, output_file: str, create_dirs: bool = True) -> None:
"""
Save all queries to a file.
Each query is saved on a separate line in its string representation.
Empty queries are filtered out and whitespace is trimmed.
Args:
output_file: Path to the output file
create_dirs: If True, creates parent directories if needed
"""
if create_dirs and os.path.dirname(output_file):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
queries = filter(None, map(lambda q: str(q).strip(), self._queries))
with open(output_file, 'w+') as f:
f.write('\n\n'.join(queries))
[docs]
def sort(self) -> None:
"""
Sort queries by their complexity.
Orders queries based on their complexity score while maintaining the
association between queries and their execution results if they exist.
"""
if not self._results:
self._queries = sorted(self._queries)
else:
pairs = list(zip(self._queries, self._results))
pairs.sort(key=lambda x: x[0])
self._queries, self._results = map(list, zip(*pairs))