import typing as t
from .group_by_aggregation import GroupByAggregation
from .merge import Merge
from .operation import Operation
from .projection import Projection
from .selection import Selection
[docs]
class Query:
"""
Represents a complete database query with tracking for query complexity.
A query consists of a target entity and a sequence of operations to be
applied to that entity. Query complexity is determined primarily by the
number of merge operations and their nesting depth.
Attributes:
entity (str): The name of the target entity.
operations (List[Operation]): List of operations to apply.
multi_line (bool): Whether to format output across multiple lines.
columns (Set[str]): Columns available for operations.
"""
def __init__(
self,
entity: str,
operations: t.List[Operation],
multi_line: bool,
columns: t.Set[str],
):
self.entity = entity
self.operations = operations
self.multi_line = multi_line
self.columns = columns
[docs]
def __str__(self) -> str:
return (
self.format_multi_line()[0]
if self.multi_line
else f'{self.entity}{''.join(op.apply(self.entity) for op in self.operations)}'
)
[docs]
def __hash__(self) -> int:
"""Hash based on complexity and string representation."""
return hash((self.complexity, str(self)))
[docs]
def __eq__(self, other: object) -> bool:
"""Equality comparison based on complexity and string representation."""
if not isinstance(other, Query):
return NotImplemented
return (self.complexity, str(self)) == (other.complexity, str(other))
[docs]
def __lt__(self, other: object) -> bool:
"""Less than comparison based on complexity and string representation."""
if not isinstance(other, Query):
return NotImplemented
return (self.complexity, str(self)) < (other.complexity, str(other))
@property
def complexity(self) -> int:
"""
Calculate query complexity based on all operations and their details.
Complexity is determined by:
1. Base complexity: Total number of operations
2. Merge complexity:
- Each merge adds weight of 3 (more complex than other operations)
- Additional complexity from nested queries
3. Selection complexity: Number of conditions in each selection
4. Projection complexity: Number of columns being projected
5. GroupBy complexity: Number of grouping columns plus weight of aggregation
Returns:
int: Complexity score for the query
"""
def get_merge_complexity(op: Operation) -> int:
return (
3 + sum(get_operation_complexity(nested_op) for nested_op in op.right.operations)
if isinstance(op, Merge)
else 0
)
def get_operation_complexity(op: Operation) -> int:
if isinstance(op, Selection):
return 1 + len(op.conditions)
elif isinstance(op, Projection):
return 1 + len(op.columns)
elif isinstance(op, GroupByAggregation):
return 2 + len(op.group_by_columns)
elif isinstance(op, Merge):
return get_merge_complexity(op)
raise ValueError('Unsupported operation type')
base_complexity = len(self.operations)
operation_complexity = sum(get_operation_complexity(op) for op in self.operations)
return base_complexity + operation_complexity
@property
def merge_count(self) -> int:
"""
Count the total number of merge operations in the query, including nested merges.
Returns:
int: Total number of merge operations
"""
return sum(
1 + sum(1 for nested_op in op.right.operations if isinstance(nested_op, Merge))
if isinstance(op, Merge)
else 0
for op in self.operations
)
@property
def merge_entities(self) -> t.Set[str]:
"""
Get the set of all entities involved in this query, including
the base entity and all merged entities.
This property maintains a complete picture of table dependencies by tracking:
1. The base entity of the query
2. All entities that have been merged directly into this query
3. All entities that have been merged into sub-queries (nested merges)
The tracking helps prevent:
- Circular dependencies (e.g., orders → customers → orders)
- Redundant joins (e.g., merging the same table multiple times)
- Invalid join paths
Returns:
Set[str]:
A set of entity names (table names) that are part of this query's join graph.
Includes both the base entity and all merged entities.
"""
merged = {self.entity}
for op in self.operations:
if isinstance(op, Merge):
merged.update(op.entities)
return merged