
sqlglot.dialects.dialect

from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
  41    """Dialects supported by SQLGLot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

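# Editor's sketch (not part of the module): because `Dialects` subclasses
# `str`, members compare equal to their string values, so a member and the
# plain string are interchangeable wherever a dialect name is expected:
#
#     >>> Dialects.DUCKDB == "duckdb"
#     True
#     >>> import sqlglot
#     >>> sqlglot.transpile("SELECT 1", read=Dialects.DUCKDB, write="mysql")
#     ['SELECT 1']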

class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""

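# Editor's sketch: the strategy drives `Dialect.normalize_identifier` below.
# For example, Snowflake uses UPPERCASE, so an unquoted identifier folds
# upward (assumes the dialect modules have been imported):
#
#     >>> ident = exp.to_identifier("FoO")
#     >>> Dialect.get_or_raise("snowflake").normalize_identifier(ident).name
#     'FOO'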

class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass

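# Editor's sketch: the metaclass auto-registers every `Dialect` subclass in
# `_Dialect.classes`, keyed by the matching `Dialects` enum value or the
# lowercased class name, so defining a subclass is enough to make it
# resolvable by name (`Custom` below is hypothetical):
#
#     class Custom(Dialect):
#         pass
#
#     assert Dialect.get("custom") is Custom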

class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

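    # Editor's sketch: for a dialect whose TIME_MAPPING contained, say,
    # {"yyyy": "%Y", "mm": "%m", "dd": "%d"} (an illustrative mapping, not one
    # shipped here), format_time("'yyyy-mm-dd'") would yield the string
    # literal '%Y-%m-%d'.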
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

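    # Editor's sketch: under the default LOWERCASE strategy, unquoted
    # identifiers fold down while quoted ones are left untouched:
    #
    #     >>> Dialect().normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect().normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'FoO'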
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

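    # Editor's sketch: with identify=False, only "unsafe" names get quoted:
    #
    #     >>> Dialect().quote_identifier(exp.to_identifier("foo"), identify=False).sql()
    #     'foo'
    #     >>> Dialect().quote_identifier(exp.to_identifier("FoO"), identify=False).sql()
    #     '"FoO"'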
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)

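# Editor's sketch: `parse`, `generate` and `transpile` wire this dialect's
# tokenizer, parser and generator together, so a SQL string can be
# round-tripped through a single dialect instance:
#
#     >>> duckdb = Dialect.get_or_raise("duckdb")
#     >>> duckdb.transpile("SELECT 1")
#     ['SELECT 1']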

DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))

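# Editor's sketch: `rename_func` is meant for a generator's TRANSFORMS table,
# e.g. `{exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT")}` renders
# the node under the given name with its flattened arguments.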

def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql

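# Editor's sketch: `if_sql("IIF")` yields a transform that renders exp.If as
# IIF(<cond>, <true>, <false>); `false_value` supplies a default third
# argument when the parsed expression has none.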

def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset

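# Editor's sketch: with a position argument, a search like STRPOS(x, sub, 3)
# becomes STRPOS(SUBSTR(x, 3), sub) + 3 - 1, i.e. the match is found in the
# truncated string and then shifted back into the original's coordinates.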

def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder

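# Editor's sketch: an illustrative parser registration would look like
# `"TO_CHAR": build_formatted_time(exp.TimeToStr, "oracle")`, which converts
# the optional format argument from the named dialect's time format into the
# Python `strftime` mapping before instantiating the expression.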

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder

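# Editor's sketch: for a three-argument call such as DATEADD(day, 5, x) the
# unit is the first argument; the two-argument form defaults the unit to DAY.
# A hypothetical registration: `"DATEADD": build_date_delta(exp.DateAdd)`.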

def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )

def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"

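# Editor's sketch: an expression parsed from TRIM(LEADING 'x' FROM col)
# renders here as TRIM(LEADING 'x' FROM col), while a plain TRIM(col) falls
# through to the generator's default TRIM/LTRIM/RTRIM handling.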

def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the stray backticks).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"

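# Editor's sketch: for engines without a boolean XOR operator, `a XOR b` is
# rewritten to the logically equivalent `(a AND (NOT b)) OR ((NOT a) AND b)`,
# which is true exactly when the two operands disagree.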

def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql

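# Editor's sketch: `date_delta_sql("DATEADD")` renders date add/diff
# expressions as DATEADD(<unit>, <delta>, <this>); with cast=True, a
# TsOrDsAdd operand is first wrapped in a cast via `ts_or_ds_add_cast`.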

def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder

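# Editor's sketch: a call parsed as JSON_EXTRACT_PATH(col, 'a', 0) becomes a
# JSONPath of [root, key 'a', subscript 0]; with zero_based_indexing=False, a
# one-based index like 1 is normalized to subscript 0.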

def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder

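# Editor's sketch: `build_default_decimal_type(38, 9)` leaves an explicitly
# parameterized DECIMAL(10, 2) untouched, but fills a bare DECIMAL in as
# DECIMAL(38, 9).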

def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
417            self.normalization_strategy = self.NORMALIZATION_STRATEGY
418        else:
419            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
420
421    def __eq__(self, other: t.Any) -> bool:
422        # Does not currently take dialect state into account
423        return type(self) == other
424
425    def __hash__(self) -> int:
426        # Does not currently take dialect state into account
427        return hash(type(self))
428
429    def normalize_identifier(self, expression: E) -> E:
430        """
431        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
432
433        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
434        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
435        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
436        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
437
438        There are also dialects like Spark, which are case-insensitive even when quotes are
439        present, and dialects like MySQL, whose resolution rules match those employed by the
440        underlying operating system, for example they may always be case-sensitive in Linux.
441
442        Finally, the normalization behavior of some engines can even be controlled through flags,
443        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
444
445        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
446        that it can analyze queries in the optimizer and successfully capture their semantics.
447        """
448        if (
449            isinstance(expression, exp.Identifier)
450            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
451            and (
452                not expression.quoted
453                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
454            )
455        ):
456            expression.set(
457                "this",
458                (
459                    expression.this.upper()
460                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
461                    else expression.this.lower()
462                ),
463            )
464
465        return expression
466
467    def case_sensitive(self, text: str) -> bool:
468        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
469        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
470            return False
471
472        unsafe = (
473            str.islower
474            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
475            else str.isupper
476        )
477        return any(unsafe(char) for char in text)
478
479    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
480        """Checks if text can be identified given an identify option.
481
482        Args:
483            text: The text to check.
484            identify:
485                `"always"` or `True`: Always returns `True`.
486                `"safe"`: Only returns `True` if the identifier is case-insensitive.
487
488        Returns:
489            Whether the given text can be identified.
490        """
491        if identify is True or identify == "always":
492            return True
493
494        if identify == "safe":
495            return not self.case_sensitive(text)
496
497        return False
498
499    def quote_identifier(self, expression: E, identify: bool = True) -> E:
500        """
501        Adds quotes to a given identifier.
502
503        Args:
504            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
505            identify: If set to `False`, the quotes will only be added if the identifier is deemed
506                "unsafe", with respect to its characters and this dialect's normalization strategy.
507        """
508        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
509            name = expression.this
510            expression.set(
511                "quoted",
512                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
513            )
514
515        return expression
516
517    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
518        if isinstance(path, exp.Literal):
519            path_text = path.name
520            if path.is_number:
521                path_text = f"[{path_text}]"
522
523            try:
524                return parse_json_path(path_text)
525            except ParseError as e:
526                logger.warning(f"Invalid JSON path syntax. {str(e)}")
527
528        return path
529
530    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
531        return self.parser(**opts).parse(self.tokenize(sql), sql)
532
533    def parse_into(
534        self, expression_type: exp.IntoType, sql: str, **opts
535    ) -> t.List[t.Optional[exp.Expression]]:
536        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
537
538    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
539        return self.generator(**opts).generate(expression, copy=copy)
540
541    def transpile(self, sql: str, **opts) -> t.List[str]:
542        return [
543            self.generate(expression, copy=False, **opts) if expression else ""
544            for expression in self.parse(sql)
545        ]
546
547    def tokenize(self, sql: str) -> t.List[Token]:
548        return self.tokenizer.tokenize(sql)
549
550    @property
551    def tokenizer(self) -> Tokenizer:
552        if not hasattr(self, "_tokenizer"):
553            self._tokenizer = self.tokenizer_class(dialect=self)
554        return self._tokenizer
555
556    def parser(self, **opts) -> Parser:
557        return self.parser_class(dialect=self, **opts)
558
559    def generator(self, **opts) -> Generator:
560        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
413    def __init__(self, **kwargs) -> None:
414        normalization_strategy = kwargs.get("normalization_strategy")
415
416        if normalization_strategy is None:
417            self.normalization_strategy = self.NORMALIZATION_STRATEGY
418        else:
419            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
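For illustration (this snippet is not part of the library's docstrings), a `normalization_strategy` keyword passed to the constructor is upper-cased and coerced to the `NormalizationStrategy` enum, per the `__init__` shown above:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # The string is upper-cased and converted to the enum member.
    d = Dialect(normalization_strategy="case_insensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE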
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

SUPPORTS_COLUMN_JOIN_MARKS = False

Whether the old-style outer join (+) syntax is supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (\\n) to its unescaped version (\n).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
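As a quick sketch of how these class-level settings drive per-dialect behavior (the printed values depend on the dialect definitions in your sqlglot version):

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")
    tsql = Dialect.get_or_raise("tsql")

    # Settings are plain class attributes, readable off any instance.
    print(duckdb.NULL_ORDERING, tsql.NULL_ORDERING)
    print(duckdb.NORMALIZE_FUNCTIONS, tsql.NORMALIZE_FUNCTIONS)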
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
349    @classmethod
350    def get_or_raise(cls, dialect: DialectType) -> Dialect:
351        """
352        Look up a dialect in the global dialect registry and return it if it exists.
353
354        Args:
355            dialect: The target dialect. If this is a string, it can be optionally followed by
356                additional key-value pairs that are separated by commas and are used to specify
357                dialect settings, such as whether the dialect's identifiers are case-sensitive.
358
359        Example:
360            >>> dialect = dialect_class = get_or_raise("duckdb")
361            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
362
363        Returns:
364            The corresponding Dialect instance.
365        """
366
367        if not dialect:
368            return cls()
369        if isinstance(dialect, _Dialect):
370            return dialect()
371        if isinstance(dialect, Dialect):
372            return dialect
373        if isinstance(dialect, str):
374            try:
375                dialect_name, *kv_pairs = dialect.split(",")
376                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
377            except ValueError:
378                raise ValueError(
379                    f"Invalid dialect format: '{dialect}'. "
380                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
381                )
382
383            result = cls.get(dialect_name.strip())
384            if not result:
385                from difflib import get_close_matches
386
387                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
388                if similar:
389                    similar = f" Did you mean {similar}?"
390
391                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
392
393            return result(**kwargs)
394
395        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
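A minimal usage sketch: the settings string both selects the dialect class and overrides its defaults, and unknown names raise a `ValueError` with a close-match hint when one exists.

    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    print(type(d).__name__)          # MySQL
    print(d.normalization_strategy)  # NormalizationStrategy.CASE_SENSITIVE

    try:
        Dialect.get_or_raise("postgre")
    except ValueError as e:
        print(e)  # includes a "Did you mean ...?" suggestion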

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
397    @classmethod
398    def format_time(
399        cls, expression: t.Optional[str | exp.Expression]
400    ) -> t.Optional[exp.Expression]:
401        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
402        if isinstance(expression, str):
403            return exp.Literal.string(
404                # the time formats are quoted
405                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
406            )
407
408        if expression and expression.is_string:
409            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
410
411        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
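A short sketch of the conversion, assuming Hive's TIME_MAPPING translates `yyyy-MM-dd` to `%Y-%m-%d` (true in current sqlglot, though mappings vary by dialect). Note that plain strings are assumed to be quoted time formats, hence the surrounding `'...'`:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    hive = Dialect["hive"]
    print(hive.format_time("'yyyy-MM-dd'"))                    # '%Y-%m-%d'
    print(hive.format_time(exp.Literal.string("yyyy-MM-dd")))  # '%Y-%m-%d'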

def normalize_identifier(self, expression: ~E) -> ~E:
429    def normalize_identifier(self, expression: E) -> E:
430        """
431        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
432
433        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
434        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
435        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
436        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
437
438        There are also dialects like Spark, which are case-insensitive even when quotes are
439        present, and dialects like MySQL, whose resolution rules match those employed by the
440        underlying operating system, for example they may always be case-sensitive in Linux.
441
442        Finally, the normalization behavior of some engines can even be controlled through flags,
443        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
444
445        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
446        that it can analyze queries in the optimizer and successfully capture their semantics.
447        """
448        if (
449            isinstance(expression, exp.Identifier)
450            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
451            and (
452                not expression.quoted
453                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
454            )
455        ):
456            expression.set(
457                "this",
458                (
459                    expression.this.upper()
460                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
461                    else expression.this.lower()
462                ),
463            )
464
465        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
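Illustrating the behaviors described above (not part of the library's docstrings):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

    # Quoted identifiers are left untouched by dialects that treat quotes
    # as a request for case sensitivity.
    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name)  # FoO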

def case_sensitive(self, text: str) -> bool:
467    def case_sensitive(self, text: str) -> bool:
468        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
469        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
470            return False
471
472        unsafe = (
473            str.islower
474            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
475            else str.isupper
476        )
477        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
479    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
480        """Checks if text can be identified given an identify option.
481
482        Args:
483            text: The text to check.
484            identify:
485                `"always"` or `True`: Always returns `True`.
486                `"safe"`: Only returns `True` if the identifier is case-insensitive.
487
488        Returns:
489            Whether the given text can be identified.
490        """
491        if identify is True or identify == "always":
492            return True
493
494        if identify == "safe":
495            return not self.case_sensitive(text)
496
497        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
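Tying this to `case_sensitive` above, a small sketch with Postgres, which lowercases unquoted identifiers:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")    # LOWERCASE strategy
    print(pg.case_sensitive("foo"))          # False: no uppercase characters
    print(pg.case_sensitive("Foo"))          # True
    print(pg.can_identify("Foo", "safe"))    # False: case would be lost unquoted
    print(pg.can_identify("Foo", "always"))  # True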

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
499    def quote_identifier(self, expression: E, identify: bool = True) -> E:
500        """
501        Adds quotes to a given identifier.
502
503        Args:
504            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
505            identify: If set to `False`, the quotes will only be added if the identifier is deemed
506                "unsafe", with respect to its characters and this dialect's normalization strategy.
507        """
508        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
509            name = expression.this
510            expression.set(
511                "quoted",
512                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
513            )
514
515        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
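For example (an illustrative sketch, again using Postgres):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.quote_identifier(exp.to_identifier("foo")).sql(dialect="postgres"))                  # "foo"
    print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
    # Case-sensitive names are quoted even with identify=False.
    print(pg.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres"))  # "Foo"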
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
517    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
518        if isinstance(path, exp.Literal):
519            path_text = path.name
520            if path.is_number:
521                path_text = f"[{path_text}]"
522
523            try:
524                return parse_json_path(path_text)
525            except ParseError as e:
526                logger.warning(f"Invalid JSON path syntax. {str(e)}")
527
528        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
530    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
531        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
533    def parse_into(
534        self, expression_type: exp.IntoType, sql: str, **opts
535    ) -> t.List[t.Optional[exp.Expression]]:
536        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
538    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
539        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
541    def transpile(self, sql: str, **opts) -> t.List[str]:
542        return [
543            self.generate(expression, copy=False, **opts) if expression else ""
544            for expression in self.parse(sql)
545        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
547    def tokenize(self, sql: str) -> t.List[Token]:
548        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
550    @property
551    def tokenizer(self) -> Tokenizer:
552        if not hasattr(self, "_tokenizer"):
553            self._tokenizer = self.tokenizer_class(dialect=self)
554        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
556    def parser(self, **opts) -> Parser:
557        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
559    def generator(self, **opts) -> Generator:
560        return self.generator_class(dialect=self, **opts)
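Putting the plumbing methods above together (a sketch; the point is the shape of each return value):

    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("duckdb")
    tokens = d.tokenize("SELECT 1 AS x")  # List[Token]
    ast = d.parse("SELECT 1 AS x")[0]     # parsed Expression (or None)
    print(d.generate(ast))                # SELECT 1 AS x
    print(d.transpile("SELECT 1 AS x"))   # ['SELECT 1 AS x']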
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
566def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
567    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
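`rename_func` is typically used inside a dialect Generator's TRANSFORMS mapping. A hypothetical mini-dialect as a sketch (the `Custom` class and the `APPROX_SET` target name are illustrative, not part of sqlglot):

    from sqlglot import exp, generator, parse_one
    from sqlglot.dialects.dialect import Dialect, rename_func

    class Custom(Dialect):  # registered as "custom" by the _Dialect metaclass
        class Generator(generator.Generator):
            TRANSFORMS = {
                **generator.Generator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_SET"),
            }

    print(parse_one("SELECT APPROX_DISTINCT(x)").sql(dialect="custom"))
    # SELECT APPROX_SET(x)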
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
570def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
571    if expression.args.get("accuracy"):
572        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
573    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
576def if_sql(
577    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
578) -> t.Callable[[Generator, exp.If], str]:
579    def _if_sql(self: Generator, expression: exp.If) -> str:
580        return self.func(
581            name,
582            expression.this,
583            expression.args.get("true"),
584            expression.args.get("false") or false_value,
585        )
586
587    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
590def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
591    this = expression.this
592    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
593        this.replace(exp.cast(this, exp.DataType.Type.JSON))
594
595    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
598def inline_array_sql(self: Generator, expression: exp.Array) -> str:
599    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
602def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
603    elem = seq_get(expression.expressions, 0)
604    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
605        return self.func("ARRAY", elem)
606    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
609def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
610    return self.like_sql(
611        exp.Like(
612            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
613        )
614    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
617def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
618    zone = self.sql(expression, "this")
619    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
622def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
623    if expression.args.get("recursive"):
624        self.unsupported("Recursive CTEs are unsupported")
625        expression.args["recursive"] = False
626    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
629def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
630    n = self.sql(expression, "this")
631    d = self.sql(expression, "expression")
632    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
635def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
636    self.unsupported("TABLESAMPLE unsupported")
637    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
640def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
641    self.unsupported("PIVOT unsupported")
642    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
645def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
646    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
649def no_comment_column_constraint_sql(
650    self: Generator, expression: exp.CommentColumnConstraint
651) -> str:
652    self.unsupported("CommentColumnConstraint unsupported")
653    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
656def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
657    self.unsupported("MAP_FROM_ENTRIES unsupported")
658    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
661def str_position_sql(
662    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
663) -> str:
664    this = self.sql(expression, "this")
665    substr = self.sql(expression, "substr")
666    position = self.sql(expression, "position")
667    instance = expression.args.get("instance") if generate_instance else None
668    position_offset = ""
669
670    if position:
671        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
672        this = self.func("SUBSTR", this, position)
673        position_offset = f" + {position} - 1"
674
675    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
678def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
679    return (
680        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
681    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
684def var_map_sql(
685    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
686) -> str:
687    keys = expression.args["keys"]
688    values = expression.args["values"]
689
690    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
691        self.unsupported("Cannot convert array columns into map.")
692        return self.func(map_func_name, keys, values)
693
694    args = []
695    for key, value in zip(keys.expressions, values.expressions):
696        args.append(self.sql(key))
697        args.append(self.sql(value))
698
699    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
702def build_formatted_time(
703    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
704) -> t.Callable[[t.List], E]:
705    """Helper used for time expressions.
706
707    Args:
708        exp_class: the expression class to instantiate.
709        dialect: target sql dialect.
710        default: the default format, True being time.
711
712    Returns:
713        A callable that can be used to return the appropriately formatted time expression.
714    """
715
716    def _builder(args: t.List):
717        return exp_class(
718            this=seq_get(args, 0),
719            format=Dialect[dialect].format_time(
720                seq_get(args, 1)
721                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
722            ),
723        )
724
725    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
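A sketch of the returned builder in action, assuming Hive's TIME_MAPPING translates `yyyy-MM-dd` to `%Y-%m-%d`:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    # The callable a parser might register for a Hive-style STR_TO_TIME.
    builder = build_formatted_time(exp.StrToTime, "hive")
    node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
    print(node.args["format"])  # '%Y-%m-%d', via Hive's TIME_MAPPING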

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
728def time_format(
729    dialect: DialectType = None,
730) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
731    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
732        """
733        Returns the time format for a given expression, unless it's equivalent
734        to the default time format of the dialect of interest.
735        """
736        time_format = self.format_time(expression)
737        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
738
739    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
742def build_date_delta(
743    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
744) -> t.Callable[[t.List], E]:
745    def _builder(args: t.List) -> E:
746        unit_based = len(args) == 3
747        this = args[2] if unit_based else seq_get(args, 0)
748        unit = args[0] if unit_based else exp.Literal.string("DAY")
749        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
750        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
751
752    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
755def build_date_delta_with_interval(
756    expression_class: t.Type[E],
757) -> t.Callable[[t.List], t.Optional[E]]:
758    def _builder(args: t.List) -> t.Optional[E]:
759        if len(args) < 2:
760            return None
761
762        interval = args[1]
763
764        if not isinstance(interval, exp.Interval):
765            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
766
767        expression = interval.this
768        if expression and expression.is_string:
769            expression = exp.Literal.number(expression.this)
770
771        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
772
773    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
776def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
777    unit = seq_get(args, 0)
778    this = seq_get(args, 1)
779
780    if isinstance(this, exp.Cast) and this.is_type("date"):
781        return exp.DateTrunc(unit=unit, this=this)
782    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
785def date_add_interval_sql(
786    data_type: str, kind: str
787) -> t.Callable[[Generator, exp.Expression], str]:
788    def func(self: Generator, expression: exp.Expression) -> str:
789        this = self.sql(expression, "this")
790        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
791        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
792
793    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
796def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
797    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
798        args = [unit_to_str(expression), expression.this]
799        if zone:
800            args.append(expression.args.get("zone"))
801        return self.func("DATE_TRUNC", *args)
802
803    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
806def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
807    if not expression.expression:
808        from sqlglot.optimizer.annotate_types import annotate_types
809
810        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
811        return self.sql(exp.cast(expression.this, target_type))
812    if expression.text("expression").lower() in TIMEZONES:
813        return self.sql(
814            exp.AtTimeZone(
815                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
816                zone=expression.expression,
817            )
818        )
819    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
822def locate_to_strposition(args: t.List) -> exp.Expression:
823    return exp.StrPosition(
824        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
825    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
828def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
829    return self.func(
830        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
831    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
834def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
835    return self.sql(
836        exp.Substring(
837            this=expression.this, start=exp.Literal.number(1), length=expression.expression
838        )
839    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
842def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
843    return self.sql(
844        exp.Substring(
845            this=expression.this,
846            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
847        )
848    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
851def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
852    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
855def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
856    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
860def encode_decode_sql(
861    self: Generator, expression: exp.Expression, name: str, replace: bool = True
862) -> str:
863    charset = expression.args.get("charset")
864    if charset and charset.name.lower() != "utf-8":
865        self.unsupported(f"Expected utf-8 character set, got {charset}.")
866
867    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
870def min_or_least(self: Generator, expression: exp.Min) -> str:
871    name = "LEAST" if expression.expressions else "MIN"
872    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
875def max_or_greatest(self: Generator, expression: exp.Max) -> str:
876    name = "GREATEST" if expression.expressions else "MAX"
877    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
880def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
881    cond = expression.this
882
883    if isinstance(expression.this, exp.Distinct):
884        cond = expression.this.expressions[0]
885        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
886
887    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
890def trim_sql(self: Generator, expression: exp.Trim) -> str:
891    target = self.sql(expression, "this")
892    trim_type = self.sql(expression, "position")
893    remove_chars = self.sql(expression, "expression")
894    collation = self.sql(expression, "collation")
895
896    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
897    if not remove_chars and not collation:
898        return self.trim_sql(expression)
899
900    trim_type = f"{trim_type} " if trim_type else ""
901    remove_chars = f"{remove_chars} " if remove_chars else ""
902    from_part = "FROM " if trim_type or remove_chars else ""
903    collation = f" COLLATE {collation}" if collation else ""
904    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
907def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
908    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
911def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
912    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
915def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
916    delim, *rest_args = expression.expressions
917    return self.sql(
918        reduce(
919            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
920            rest_args,
921        )
922    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
925def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
926    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
927    if bad_args:
928        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
929
930    return self.func(
931        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
932    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
935def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
936    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
937    if bad_args:
938        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
939
940    return self.func(
941        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
942    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
945def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
946    names = []
947    for agg in aggregations:
948        if isinstance(agg, exp.Alias):
949            names.append(agg.alias)
950        else:
951            """
952            This case corresponds to aggregations without aliases being used as suffixes
953            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
954            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
955            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
956            """
957            agg_all_unquoted = agg.transform(
958                lambda node: (
959                    exp.Identifier(this=node.name, quoted=False)
960                    if isinstance(node, exp.Identifier)
961                    else node
962                )
963            )
964            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
965
966    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
969def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
970    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
974def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
975    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
978def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
979    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
982def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
983    a = self.sql(expression.left)
984    b = self.sql(expression.right)
985    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
988def is_parse_json(expression: exp.Expression) -> bool:
989    return isinstance(expression, exp.ParseJSON) or (
990        isinstance(expression, exp.Cast) and expression.is_type("json")
991    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
994def isnull_to_is_null(args: t.List) -> exp.Expression:
995    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
 998def generatedasidentitycolumnconstraint_sql(
 999    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
1000) -> str:
1001    start = self.sql(expression, "start") or "1"
1002    increment = self.sql(expression, "increment") or "1"
1003    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
1006def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
1007    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
1008        if expression.args.get("count"):
1009            self.unsupported(f"Only two arguments are supported in function {name}.")
1010
1011        return self.func(name, expression.this, expression.expression)
1012
1013    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
1016def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1017    this = expression.this.copy()
1018
1019    return_type = expression.return_type
1020    if return_type.is_type(exp.DataType.Type.DATE):
1021        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1022        # can truncate timestamp strings, because some dialects can't cast them to DATE
1023        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1024
1025    expression.this.replace(exp.cast(this, return_type))
1026    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1029def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1030    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1031        if cast and isinstance(expression, exp.TsOrDsAdd):
1032            expression = ts_or_ds_add_cast(expression)
1033
1034        return self.func(
1035            name,
1036            unit_to_var(expression),
1037            expression.expression,
1038            expression.this,
1039        )
1040
1041    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1044def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1045    unit = expression.args.get("unit")
1046
1047    if isinstance(unit, exp.Placeholder):
1048        return unit
1049    if unit:
1050        return exp.Literal.string(unit.name)
1051    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1054def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1055    unit = expression.args.get("unit")
1056
1057    if isinstance(unit, (exp.Var, exp.Placeholder)):
1058        return unit
1059    return exp.Var(this=default) if default else None
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1062def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1063    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1064    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1065    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1066
1067    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1070def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1071    """Remove table refs from columns in when statements."""
1072    alias = expression.this.args.get("alias")
1073
1074    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1075        return self.dialect.normalize_identifier(identifier).name if identifier else None
1076
1077    targets = {normalize(expression.this.this)}
1078
1079    if alias:
1080        targets.add(normalize(alias.this))
1081
1082    for when in expression.expressions:
1083        when.transform(
1084            lambda node: (
1085                exp.column(node.this)
1086                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1087                else node
1088            ),
1089            copy=False,
1090        )
1091
1092    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1095def build_json_extract_path(
1096    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1097) -> t.Callable[[t.List], F]:
1098    def _builder(args: t.List) -> F:
1099        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1100        for arg in args[1:]:
1101            if not isinstance(arg, exp.Literal):
1102                # We use the fallback parser because we can't really transpile non-literals safely
1103                return expr_type.from_arg_list(args)
1104
1105            text = arg.name
1106            if is_int(text):
1107                index = int(text)
1108                segments.append(
1109                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1110                )
1111            else:
1112                segments.append(exp.JSONPathKey(this=text))
1113
1114        # This is done to avoid failing in the expression validator due to the arg count
1115        del args[2:]
1116        return expr_type(
1117            this=seq_get(args, 0),
1118            expression=exp.JSONPath(expressions=segments),
1119            only_json_types=arrow_req_json_type,
1120        )
1121
1122    return _builder
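For illustration, literal key and index arguments are folded into a single JSONPath expression (output rendering varies by dialect, so the comment below is only indicative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    print(node.sql())  # e.g. JSON_EXTRACT(doc, '$.a[0]') with the default generator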
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1125def json_extract_segments(
1126    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1127) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1128    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1129        path = expression.expression
1130        if not isinstance(path, exp.JSONPath):
1131            return rename_func(name)(self, expression)
1132
1133        segments = []
1134        for segment in path.expressions:
1135            path = self.sql(segment)
1136            if path:
1137                if isinstance(segment, exp.JSONPathPart) and (
1138                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1139                ):
1140                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1141
1142                segments.append(path)
1143
1144        if op:
1145            return f" {op} ".join([self.sql(expression.this), *segments])
1146        return self.func(name, expression.this, *segments)
1147
1148    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1151def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1152    if isinstance(expression.this, exp.JSONPathWildcard):
1153        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1154
1155    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1158def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1159    cond = expression.expression
1160    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1161        alias = cond.expressions[0]
1162        cond = cond.this
1163    elif isinstance(cond, exp.Predicate):
1164        alias = "_u"
1165    else:
1166        self.unsupported("Unsupported filter condition")
1167        return ""
1168
1169    unnest = exp.Unnest(expressions=[expression.this])
1170    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1171    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1174def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1175    return self.func(
1176        "TO_NUMBER",
1177        expression.this,
1178        expression.args.get("format"),
1179        expression.args.get("nlsparam"),
1180    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1183def build_default_decimal_type(
1184    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1185) -> t.Callable[[exp.DataType], exp.DataType]:
1186    def _builder(dtype: exp.DataType) -> exp.DataType:
1187        if dtype.expressions or precision is None:
1188            return dtype
1189
1190        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1191        return exp.DataType.build(f"DECIMAL({params})")
1192
1193    return _builder
def build_timestamp_from_parts(args: List) -> sqlglot.expressions.Func:
1196def build_timestamp_from_parts(args: t.List) -> exp.Func:
1197    if len(args) == 2:
1198        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
1199        # so we parse this into Anonymous for now instead of introducing complexity
1200        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
1201
1202    return exp.TimestampFromParts.from_arg_list(args)
def sha256_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SHA2) -> str:
1205def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
1206    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)