sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # alias name -> 1-based position of the aliased projection in the SELECT list
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified columns can refer to a projection alias.
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # Captured before the ROW_NUMBER() projection is added below, so the outer query
        # only re-selects the original projections.
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY decides which row of each partition is kept.
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c*" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve the quoting of the alias (or column identifier) when re-selecting it.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT * every column is already available to the outer query.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each UNNEST join of a SELECT is removed and, for every (array, column alias) pair of the
    UNNEST, a LATERAL VIEW EXPLODE (POSEXPLODE when the UNNEST has an offset) is appended to
    the SELECT's laterals, reusing the UNNEST's table alias.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live "joins" list below,
        # and removing from a list while iterating over it skips the element that follows
        # (e.g. the second of two consecutive UNNEST joins would never be converted).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without an alias there are no column names to zip against, so
                # the join is dropped and no LATERAL VIEW is emitted -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: first value of the generated position series (and of the UNNEST offsets
            it is compared against).

    Returns:
        A transform that rewrites [POS]EXPLODE projections of a SELECT into CROSS JOINed UNNESTs.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve a fresh name so later calls don't reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its end is filled in after all explodes are processed.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        # The ordered column moves into WITHIN GROUP (ORDER BY ...), the quantile into `this`.
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # For a UNION, the column names come from its left-hand (anchor) query.
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    A SEMI join with an ON condition becomes `EXISTS(SELECT 1 FROM <joined> WHERE <on>)` and an
    ANTI join becomes `NOT EXISTS(...)`; the predicate is added to the query's WHERE clause and
    the join is removed. Joins without an ON condition are left untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot because matching joins are popped out of the SELECT below.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert nested CTEs before the CTE that uses them, so they are defined first.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip table, db and catalog qualifiers from every column, keeping only its name."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove UNIQUE column constraints from a CREATE statement.

    For each `exp.UniqueColumnConstraint`, its parent node is popped (presumably the
    column-constraint wrapper -- confirm against the expression tree).
    """
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TEMPORARY TABLE statements.

    A temporary CREATE TABLE ... AS <query> becomes CREATE TEMPORARY VIEW with the same name and
    query; a temporary CREATE TABLE without a query is passed to `tmp_storage_provider`. Any
    other CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. When every PARTITIONED BY entry
    is a typed column definition, those definitions are appended to the table schema and the
    PARTITIONED BY property is replaced by a tuple of their names.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts key-value (`exp.PropertyEQ`) struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY references to SELECT aliases as 1-based positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # alias name -> position of that aliased projection (later duplicates win)
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        # Qualified columns (t.b) can never refer to a projection alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON into a filtered subquery that uses ROW_NUMBER().

    Meant for dialects without SELECT DISTINCT ON but with window functions. The query gets an
    extra ROW_NUMBER() projection partitioned by the DISTINCT ON columns (ordered by the query's
    ORDER BY, or by those same columns when there is none). It is then wrapped as subquery `_t`,
    and the outer query keeps only the rows numbered 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not (distinct and distinct.args.get("on") and isinstance(distinct.args["on"], exp.Tuple)):
        return expression

    on_columns = distinct.pop().args["on"].expressions
    # Grab the projection list before the ROW_NUMBER() column is added below.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=on_columns)

    existing_order = expression.args.get("order")
    if existing_order:
        ordering = existing_order.pop()
    else:
        ordering = exp.Order(expressions=[c.copy() for c in on_columns])
    window.set("order", ordering)

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <original names> FROM (<inner query>) AS _t WHERE <qualify condition>` when the
        input is a SELECT with QUALIFY; otherwise the input unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c*" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Re-select a projection by name, keeping the identifier's quoting when it has one.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projection list is built now, before window/column projections are added
        # to the inner query below, so those helper projections are not re-selected.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT * every column is already exposed to the outer query, so only
        # window functions need to be projected.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases used inside the window, since the window is
                # projected in the same SELECT that defines those aliases.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition: filter on its alias directly.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced by QUALIFY but not projected: project it in the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop type parameters (such as precision) from every data type inside `expression`.

    Some dialects accept the precision of parameterized types only in DDL, not in other
    expressions. This transform removes every `exp.DataTypeParam` argument from each
    `exp.DataType` node; other type arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Strip UNNEST table-alias qualifiers from column references.

    Such qualifiers are added by the optimizer's qualify_columns step. Only UNNESTs that appear
    directly in a FROM or JOIN of this SELECT's scope are considered. A matching qualifier is
    removed from a column's table part or, failing that, from its db part.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each UNNEST join of a SELECT is removed. For every (array, column alias) pair of that
    UNNEST, a LATERAL VIEW EXPLODE (POSEXPLODE when the UNNEST has an offset) is appended to
    the SELECT's laterals, reusing the UNNEST's table alias.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live "joins" list below,
        # and removing from a list while iterating over it skips the element that follows
        # (e.g. the second of two consecutive UNNEST joins would never be converted).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without an alias there are no column names to zip against, so
                # the join is dropped and no LATERAL VIEW is emitted -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series. It is also used to adjust
            the array-size bounds that positions are compared against.

    Returns:
        A transform that rewrites [POS]EXPLODE projections of a SELECT into CROSS JOINed
        UNNESTs, aligned against a shared GENERATE_SERIES of positions.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused name and reserve it so later calls don't reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded argument; used to bound the position series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series UNNEST(GENERATE_SERIES(index_offset, <end>)); its end is
            # filled in after the loop, once all exploded arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalise the projection into an exp.Alias (`alias`) whose alias is set below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # e.g. POSEXPLODE(x) AS (pos, col): first alias is the position.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # When the array is NULL or empty, substitute a one-element array
                        # (a safe offset-bracket lookup) -- presumably so the row still yields
                        # one NULL element, matching EXPLODE_OUTER semantics. TODO confirm.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is always needed: it names the UNNEST offset below.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # The element is emitted only on the row where the series position
                    # matches this UNNEST's offset; otherwise the IF has no ELSE branch.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the element column.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The first exploded projection attaches the shared position series.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` becomes the last valid offset for this array when index_offset != 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where positions line up. Also keep rows past the end of this
                    # array, paired with its last element, so shorter arrays don't filter out
                    # positions that longer arrays still need.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Move a percentile's ordering column into an explicit WITHIN GROUP (ORDER BY ...) clause.

    Only `exp.PERCENTILES` nodes that carry a second argument and are not already wrapped in a
    WITHIN GROUP are rewritten. The node's first argument becomes the ORDER BY key, and its second
    argument takes the first argument's place. Any other node is returned untouched.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordering_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordering_column)]),
    )
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Collapse `<percentile>(q) WITHIN GROUP (ORDER BY x)` into an `exp.ApproxQuantile(x, q)`.

    The WITHIN GROUP node is replaced in the tree by the new ApproxQuantile node, which is
    returned. Its input is the first `exp.Ordered` found under the WITHIN GROUP, and its quantile
    is the percentile's argument. Anything else is returned untouched.
    """
    matches = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not matches:
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(
        exp.ApproxQuantile(this=ordered.this, quantile=expression.this.this)
    )
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every column-less CTE of a recursive WITH an explicit column list.

    Column names come from the projections' output names (for a UNION, from its left-hand
    query). Projections without a name get generated `_c_...` names from one sequence shared
    across the whole WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        anchor = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        column_names = [
            exp.to_identifier(projection.alias_or_name or fallback_name())
            for projection in anchor.selects
        ]
        cte_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `CAST('epoch' AS <temporal type>)` (or TRY_CAST) to cast '1970-01-01 00:00:00'.

    The 'epoch' match is case-insensitive. Only the cast's operand is replaced; the cast node
    itself is returned.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch_literal = expression.name.lower() == "epoch"
    if is_epoch_literal and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each `SEMI JOIN x ON <cond>` is removed from the SELECT and replaced by an
    `EXISTS (SELECT 1 FROM x WHERE <cond>)` filter added through `expression.where`; ANTI joins
    get `NOT EXISTS (...)` instead. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (non-SELECT expressions are returned as-is).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from `expression`, which (in
        # sqlglot versions that pop list children in place) shrinks the live "joins" list and
        # would otherwise make this loop skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT with exactly one FULL OUTER join as the UNION of two copies of the query,
    one using a LEFT OUTER join and one using a RIGHT OUTER join.

    Queries with zero or several FULL joins are returned unchanged. The LIMIT is removed from
    the LEFT-join query (the RIGHT-join copy keeps it), and the WITH clause is kept only on the
    LEFT-join query.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # Copy before touching the original so the RIGHT-join variant still sees the LIMIT.
    right_variant = expression.copy()
    expression.set("limit", None)
    position, full_join = full_joins[0]
    full_join.set("side", "left")
    right_variant.args["joins"][position].set("side", "right")
    right_variant.args.pop("with", None)  # CTEs stay only on the LEFT side

    return exp.union(expression, right_variant, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the statement's top-level WITH clause.

    Dialects such as Hive, T-SQL and Spark < 3 only accept CTEs at the top level, so e.g.

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    is rejected there. If there is no top-level WITH yet, the first nested one found is promoted.
    Otherwise the nested CTEs are merged into it: they are inserted just before the top-level CTE
    that contained them, or appended at the end. The merged clause becomes RECURSIVE if any
    merged clause was.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not top_with:
            top_with = nested_with.pop()
            expression.set("with", top_with)
            continue

        if nested_with.recursive:
            top_with.set("recursive", True)

        # Look up the enclosing CTE while `nested_with` is still attached to the tree.
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            position = top_with.expressions.index(enclosing_cte)
            top_with.expressions[position:position] = nested_with.expressions
            top_with.set("expressions", top_with.expressions)
        else:
            top_with.set("expressions", top_with.expressions + nested_with.expressions)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric-valued conditions into explicit comparisons against zero.

    Every node of `expression` is handed to `sqlglot.optimizer.canonicalize.ensure_bools`, which
    decides which sub-nodes are used as conditions. Such a node is rewritten as `<node> <> 0` if
    it is any of:
    - a number literal;
    - a non-subquery-predicate typed UNKNOWN or numeric;
    - a column with no type.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite temporary-table CREATE statements.

    A temporary CREATE TABLE is one whose properties include an `exp.TemporaryProperty`.

    Args:
        expression: the CREATE statement to transform; must be an `exp.Create`.
        tmp_storage_provider: applied to a temporary CREATE TABLE that has no AS <query> body;
            its result is returned instead of the statement. Defaults to the identity.

    Returns:
        - For a temporary CTAS: a new `exp.Create` of kind "TEMPORARY VIEW" that keeps only the
          original name and query (other arguments such as properties are not carried over).
        - For any other temporary CREATE TABLE: the result of `tmp_storage_provider`.
        - Otherwise: the input unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only CREATE TABLE / CREATE VIEW statements that have a column schema are affected. Column
    names are matched case-insensitively.

    Args:
        expression: the CREATE statement to transform; must be an `exp.Create`.

    Returns:
        The same (mutated) expression.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}

            # Split the schema in a single pass by name, instead of re-checking every column
            # against the partition list with structural `==` (O(n * m) subtree comparisons).
            partitions = []
            remaining = []
            for col in schema.expressions:
                if col.name.upper() in partition_names:
                    partitions.append(col)
                else:
                    remaining.append(col)

            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the CREATE statement's column schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE; SQLGlot
    currently uses the DATASOURCE format for Spark 3. The rewrite only applies when
    PARTITIONED BY holds a schema whose entries are all column definitions with a type. Each
    definition is appended to `expression.this`, and PARTITIONED BY is left with a tuple of the
    column names.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(col, exp.ColumnDef) and col.kind for col in partition_defs):
        return expression

    names_only = exp.Tuple(expressions=[exp.to_identifier(col.this) for col in partition_defs])
    table_schema = expression.this
    for col in partition_defs:
        table_schema.append("expressions", col)
    prop.set("this", names_only)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `key := value` style struct arguments as aliases, e.g. STRUCT(1 AS y).

    Every `exp.PropertyEQ` argument of an `exp.Struct` becomes `value AS key`; other arguments
    are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)
    expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL.

    Conversion uses the generator's "<key>_sql" method for the resulting expression when one
    exists. Otherwise it uses the matching `Generator.TRANSFORMS` function, but only if the
    transforms changed the expression's type. If the type is unchanged, re-entering that
    function could loop forever: functions fall back to `function_fallback_sql`, and other
    expressions raise `ValueError`.

    Args:
        transforms: sequence of transform functions. These will be called in order; an empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform. It raises `ValueError` when no
        suitable SQL handler exists for the transformed expression.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (used only when the transforms changed the expression's type).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.