Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
 97    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
 98    corresponding expression to avoid creating invalid column references.
 99    """
100    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
101        taken = set(expression.named_selects)
102        for select in expression.selects:
103            if not select.alias_or_name:
104                alias = find_new_name(taken, "_c")
105                select.replace(exp.alias_(select, alias))
106                taken.add(alias)
107
108        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
109            alias_or_name = select.alias_or_name
110            identifier = select.args.get("alias") or select.this
111            if isinstance(identifier, exp.Identifier):
112                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
113            return alias_or_name
114
115        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
116        qualify_filters = expression.args["qualify"].pop().this
117        expression_by_alias = {
118            select.alias: select.this
119            for select in expression.selects
120            if isinstance(select, exp.Alias)
121        }
122
123        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
124        for select_candidate in qualify_filters.find_all(select_candidates):
125            if isinstance(select_candidate, exp.Window):
126                if expression_by_alias:
127                    for column in select_candidate.find_all(exp.Column):
128                        expr = expression_by_alias.get(column.name)
129                        if expr:
130                            column.replace(expr)
131
132                alias = find_new_name(expression.named_selects, "_w")
133                expression.select(exp.alias_(select_candidate, alias), copy=False)
134                column = exp.column(alias)
135
136                if isinstance(select_candidate.parent, exp.Qualify):
137                    qualify_filters = column
138                else:
139                    select_candidate.replace(column)
140            elif select_candidate.name not in expression.named_selects:
141                expression.select(select_candidate.copy(), copy=False)
142
143        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
144            qualify_filters, copy=False
145        )
146
147    return expression
148
149
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform strips the `DataTypeParam` arguments from every
    data type found in the given expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
161
162
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Strip unnest table aliases (added by the optimizer's qualify_columns step) from columns."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Only unnests used as a FROM / JOIN source are considered
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
181
182
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose source is an UNNEST is removed, and one LATERAL VIEW is
    appended per (unnested expression, alias column) pair. POSEXPLODE is used instead of
    EXPLODE when the UNNEST has an offset.

    Note that pairs are produced by zipping the unnested expressions with the alias columns,
    so an UNNEST without an alias yields no lateral view even though its join is removed.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing from a list while iterating over it skips the element after each removal
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
206
207
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts explode/posexplode projections into unnest joins.

    Args:
        index_offset: the start value of the generated position series; it is also used to
            adjust the array-size bounds that the series positions are compared against.

    Returns:
        A function that rewrites `exp.Select` nodes in place and returns any other
        expression untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        used_select_names = set(expression.named_selects)
        used_source_names = {name for name, _ in Scope(expression).references}

        def claim_name(names: t.Set[str], base: str) -> str:
            # Reserve the generated name so that later calls can't hand it out again
            fresh = find_new_name(names, base)
            names.add(fresh)
            return fresh

        array_sizes: t.List[exp.Condition] = []
        series_alias = claim_name(used_select_names, "pos")
        series_start = exp.GenerateSeries(start=exp.Literal.number(index_offset))
        series = exp.alias_(
            exp.Unnest(expressions=[series_start]),
            claim_name(used_source_names, "_u"),
            table=[series_alias],
        )

        # a snapshot is iterated because expression.selects is mutated inside the loop
        for projection in list(expression.selects):
            explode = projection.find(exp.Explode)
            if not explode:
                continue

            pos_alias = ""
            explode_alias = ""

            if isinstance(projection, exp.Alias):
                explode_alias = projection.args["alias"]
                aliased = projection
            elif isinstance(projection, exp.Aliases):
                pos_alias = projection.aliases[0]
                explode_alias = projection.aliases[1]
                aliased = projection.replace(exp.alias_(projection.this, "", copy=False))
            else:
                aliased = projection.replace(exp.alias_(projection, ""))
                explode = aliased.find(exp.Explode)
                assert explode

            is_posexplode = isinstance(explode, exp.Posexplode)
            explode_arg = explode.this

            if isinstance(explode, exp.ExplodeOuter):
                bracket = explode_arg[0]
                bracket.set("safe", True)
                bracket.set("offset", True)
                is_empty = exp.func(
                    "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                ).eq(0)
                explode_arg = exp.func(
                    "IF",
                    is_empty,
                    exp.array(bracket, copy=False),
                    explode_arg,
                )

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(explode_arg, exp.Column):
                used_select_names.add(explode_arg.output_name)

            unnest_source_alias = claim_name(used_source_names, "_u")

            if not explode_alias:
                explode_alias = claim_name(used_select_names, "col")

                if is_posexplode:
                    pos_alias = claim_name(used_select_names, "pos")

            if not pos_alias:
                pos_alias = claim_name(used_select_names, "pos")

            aliased.set("alias", exp.to_identifier(explode_alias))

            series_table_alias = series.args["alias"].this

            # The exploded value is only emitted on rows where the series position lines up
            # with the unnest's own position
            positions_match = exp.column(series_alias, table=series_table_alias).eq(
                exp.column(pos_alias, table=unnest_source_alias)
            )
            explode.replace(
                exp.If(
                    this=positions_match,
                    true=exp.column(explode_alias, table=unnest_source_alias),
                )
            )

            if is_posexplode:
                # The position projection goes right after the exploded value's projection
                projections = expression.expressions
                insert_at = projections.index(aliased) + 1
                pos_projection = exp.If(
                    this=exp.column(series_alias, table=series_table_alias).eq(
                        exp.column(pos_alias, table=unnest_source_alias)
                    ),
                    true=exp.column(pos_alias, table=unnest_source_alias),
                ).as_(pos_alias)
                projections.insert(insert_at, pos_projection)
                expression.set("expressions", projections)

            if not array_sizes:
                # The series source is attached once, when the first explode is found
                if expression.args.get("from"):
                    expression.join(series, copy=False, join_type="CROSS")
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
            array_sizes.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            unnest_source = exp.alias_(
                exp.Unnest(
                    expressions=[explode_arg.copy()],
                    offset=exp.to_identifier(pos_alias),
                ),
                unnest_source_alias,
                table=[explode_alias],
            )
            expression.join(unnest_source, join_type="CROSS", copy=False)

            if index_offset != 1:
                size = size - 1

            same_position = exp.column(series_alias, table=series_table_alias).eq(
                exp.column(pos_alias, table=unnest_source_alias)
            )
            past_the_end = (exp.column(series_alias, table=series_table_alias) > size).and_(
                exp.column(pos_alias, table=unnest_source_alias).eq(size)
            )
            expression.where(same_position.or_(past_the_end), copy=False)

        if array_sizes:
            end: exp.Condition = exp.Greatest(this=array_sizes[0], expressions=array_sizes[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
357
358
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Wraps two-argument percentile calls in a WITHIN GROUP (ORDER BY ...) clause."""
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # The first argument becomes the ORDER BY term; the second one takes its place
    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
372
373
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Replaces a percentile wrapped in WITHIN GROUP (ORDER BY ...) by an ApproxQuantile call."""
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    # The first ORDER BY term found holds the value the quantile is computed over
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
386
387
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Fills in missing column lists of recursive CTEs using their projections' output names."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Source of fallback names for projections that have no output name
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        if cte.args["alias"].columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            # For a set operation, the names come from its left-hand side
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        cte.args["alias"].set("columns", column_names)

    return expression
405
406
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Swap the 'epoch' operand of a cast to a temporal type for the equivalent date literal."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
417
418
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent [NOT] EXISTS predicates in the WHERE clause."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        # The join is dropped and its condition lives on as a filter
        join.pop()
        expression.where(predicate, copy=False)

    return expression
434
435
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a query that has a FULL OUTER join as the union of two copies of it, one using
    a LEFT OUTER join and the other a RIGHT OUTER join. Only queries with exactly one
    FULL OUTER join are transformed; anything else is returned as is.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    position, left_join = full_joins[0]

    # The copy is taken before the original is altered, so it still carries the limit
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
460
461
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoists every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one is promoted as a whole
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        enclosing_cte = inner_with.find_ancestor(exp.CTE)
        inner_with.pop()

        if enclosing_cte:
            # CTEs nested in another CTE are placed right before it
            position = top_level_with.expressions.index(enclosing_cte)
            top_level_with.expressions[position:position] = inner_with.expressions
            top_level_with.set("expressions", top_level_with.expressions)
        else:
            merged = top_level_with.expressions + inner_with.expressions
            top_level_with.set("expressions", merged)

    return expression
499
500
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Turns numeric values used in conditions into explicit `<> 0` comparisons."""
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _to_comparison(node: exp.Expression) -> None:
        numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        # Columns that carry no type are compared against 0 as well
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_bools(node, _to_comparison)

    return expression
520
521
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Removes every part but the last one from each column found in the expression."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
529
530
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Removes the parent node of every `UniqueColumnConstraint` found in a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = constraint.parent
        if owner:
            owner.pop()

    return expression
538
539
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps CREATE TABLE statements that have a `TemporaryProperty` to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE statement that will be transformed.
        tmp_storage_provider: called with the statement when it's a temporary table that
            isn't created from a query (i.e. not a CTAS); its result is returned instead.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    props = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in props)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
562
563
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names, the matching column definitions are moved
    out of the create statement's schema and into a schema held by the property.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema):
        return expression

    if expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    # Names are matched case-insensitively
    partition_names = {column.name.upper() for column in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
585
586
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Moves typed column definitions from the PARTITIONED BY property into the table's schema,
    leaving only the column names in the property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    # Every partition entry must be a column definition that has a type
    for column_def in prop.this.expressions:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in prop.this.expressions]
    )

    schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)
    return expression
610
611
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Turns the key-value (`PropertyEQ`) arguments of a struct into aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # The key becomes the alias of the value
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
624
625
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. The
            sequence may be empty, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has
            neither a "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type, in order to detect whether the transforms changed it
        expression_type = type(expression)

        # A plain loop (instead of special-casing transforms[0]) also handles an empty
        # sequence of transforms, which would otherwise raise an IndexError.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY terms that point at a projection alias into that projection's position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    if not isinstance(expression.parent, exp.Select):
        return expression

    # 1-based position of every aliased projection in the enclosing SELECT list
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified columns can be references to a projection alias
        if isinstance(term, exp.Column) and not term.table:
            if term.name in position_by_alias:
                term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as an outer query filtering on a ROW_NUMBER() window.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    distinct_on = distinct.args.get("on")
    if not distinct_on or not isinstance(distinct_on, exp.Tuple):
        return expression

    # Detach DISTINCT from the query; its ON columns become the window's partition
    distinct.pop()
    partition_cols = distinct_on.expressions

    # Captured before the window projection is appended, so the outer query doesn't select it
    outer_projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    if order:
        # The query's ORDER BY is moved into the window
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    row_number.set("order", window_order)

    expression.select(exp.alias_(row_number, row_number_alias), copy=False)

    inner = expression.subquery("_t", copy=False)
    outer = exp.select(*outer_projections, copy=False).from_(inner, copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery "_t") when `expression` is a
        SELECT with a QUALIFY clause; otherwise `expression` itself, unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every projection that has no output name a fresh "_c"-based alias.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Return a column carrying the identifier's "quoted" flag when the projection's
            # alias (or `this`) is an Identifier; otherwise return the plain output name.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any helper projections are appended below, so the outer query only
        # selects the projections that were present at this point.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY clause and keep its condition for the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this
        # Maps each projection alias to the expression it aliases.
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star selects only windows are considered; otherwise columns are considered too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace references to projection aliases inside the window with the
                    # aliased expressions themselves.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w"-based alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition, so the filter becomes the column.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter but not projected: add it to the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (e.g. precision) from every data type found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes it from the types in such expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with `DataTypeParam` entries dropped from each `DataType`.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; columns of SELECT nodes lose their `table` (or `db`) part when
        that part names an UNNEST that sits directly under a FROM or JOIN.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect aliases of UNNESTs used as sources in this scope.
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed, and one LATERAL VIEW is appended
    per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has an
    offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression (SELECT nodes are modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live "joins" list below,
        # and removing from a list while iterating over it makes the loop skip the next join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # Without alias columns the zip below is empty, so no lateral is produced.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series. It also controls how the
            array sizes are adjusted when computing the series' end and the WHERE conditions.

    Returns:
        A transform that rewrites SELECT nodes in place and returns any expression it is given.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated projection / source aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name not present in `names` and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per explode found; used to compute the series' end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(start=index_offset)) aliased as a table with one column.
            # The series' "end" is only set at the bottom, once all array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # Single alias: it names the exploded value.
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias node (named further below)
                        # and re-locate the explode inside the replacement node.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg),
                        # with the bracket's "safe" and "offset" args both set to True.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever of the value / position aliases were not provided.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(series.pos = unnest.pos, unnest.value).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode encountered: attach the series source exactly once,
                        # as a CROSS JOIN if the query has a FROM, otherwise as its FROM.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` is compared against positions below; it is decremented unless
                    # index_offset is 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or where
                    # the series has run past `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for index_offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP clause.

    The percentile's first argument becomes the ORDER BY target of the WITHIN GROUP, and its
    second argument becomes the percentile's only argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A `WithinGroup` node wrapping the percentile, or `expression` itself when it is not a
        percentile, is already inside a WITHIN GROUP, or has no second argument.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` with an `ApproxQuantile` call.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` node that replaced the WITHIN GROUP, or `expression` itself when
        it is not a WITHIN GROUP over a percentile with an ORDER BY.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Define the columns of recursive CTEs from the output names of their projections.

    Only CTEs of a recursive WITH whose alias has no column list are touched. For a UNION the
    left-hand query provides the names; unnamed projections get names from a "_c_" sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts to temporal types by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; the operand of a matching CAST / TRY_CAST is replaced in place.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed and replaced by a WHERE predicate
    of the form `[NOT] EXISTS(SELECT 1 FROM <join target> WHERE <on>)`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a SELECT.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with one FULL OUTER join as a union of a LEFT and a RIGHT join query.

    The transformation is only applied to SELECTs that contain exactly one FULL join. The
    left branch is the original query with its limit cleared; the right branch is a copy of
    the original query (taken before the limit is cleared) without its CTEs.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of both branches, or `expression` itself when the rewrite does not apply.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    right_branch = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_branch.args["joins"][position].set("side", "right")
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_branch, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the top-level WITH of the expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    hoisted = expression.args.get("with")

    for nested in expression.find_all(exp.With):
        if nested.parent is expression:
            continue

        if not hoisted:
            # No top-level WITH yet: the first nested one found takes that role.
            hoisted = nested.pop()
            expression.set("with", hoisted)
            continue

        if nested.recursive:
            hoisted.set("recursive", True)

        enclosing_cte = nested.find_ancestor(exp.CTE)
        nested.pop()

        ctes = hoisted.expressions
        if enclosing_cte:
            # CTEs defined inside another CTE are placed right before that CTE.
            position = ctes.index(enclosing_cte)
            ctes[position:position] = nested.expressions
            hoisted.set("expressions", ctes)
        else:
            hoisted.set("expressions", ctes + nested.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Convert numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools` helper,
    together with a callback that rewrites a node `x` as `x <> 0`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if node.is_number:
            should_compare = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            should_compare = True
        else:
            # Columns without any type information are compared against zero as well.
            should_compare = isinstance(node, exp.Column) and not node.type

        if should_compare:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _ensure_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column, keeping only the last part of each.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything but the last part (i.e. the table, db and catalog args) is removed.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `UniqueColumnConstraint` found in a CREATE statement.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        holder = unique.parent
        if holder:
            holder.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE ... AS <query> to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: applied to temporary CREATE TABLE statements that have no
            source query; defaults to returning the statement unchanged.

    Returns:
        A new `Create` of kind "TEMPORARY VIEW" for temporary CTAS statements, the result of
        `tmp_storage_provider` for other temporary tables, and `expression` itself otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []

    temporary = False
    for prop in property_nodes:
        if isinstance(prop, exp.TemporaryProperty):
            temporary = True
            break

    if expression.kind != "TABLE" or not temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the partition columns out of the table schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names (i.e. not already a schema), the matching
    column definitions are removed from the create statement's schema and become the
    property's schema. Names are matched case-insensitively.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {name.name.upper() for name in prop.this.expressions}

    partitions = []
    for column_def in schema.expressions:
        if column_def.name.upper() in partition_names:
            partitions.append(column_def)

    remaining = [column_def for column_def in schema.expressions if column_def not in partitions]
    schema.set("expressions", remaining)

    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Append typed PARTITIONED BY column definitions to the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY value is a schema made only of column definitions that have a
    kind, those definitions are appended to the create statement's schema and the property
    value becomes a tuple of the column identifiers.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_defs):
        return expression

    identifiers = []
    for column_def in partition_defs:
        identifiers.append(exp.to_identifier(column_def.this))
    partition_names = exp.Tuple(expressions=identifiers)

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert key-value struct arguments to aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; for structs, each `PropertyEQ` argument is replaced by an alias
        of its value named after its key, and other arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    members = []
    for member in expression.expressions:
        if isinstance(member, exp.PropertyEQ):
            member = exp.alias_(member.expression, member.this)
        members.append(member)

    expression.set("expressions", members)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted, in which case the expression is generated as it is.

    Returns:
        Function that can be used as a generator transform. It raises `ValueError` when the
        transformed expression has no "_sql" method and either no `TRANSFORMS` entry, or a
        `TRANSFORMS` entry for a non-function expression whose type did not change.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so that a TRANSFORMS re-entry can be detected below.
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method for the resulting expression.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.