sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
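The listing above is dense, so a few short usage sketches follow. The first exercises `Dialect.get_or_raise` as its docstring describes: the settings string after the dialect name follows the documented `'dialect [, k1 = v2 [, ...]]'` format, and unknown names raise a `ValueError` with a close-match suggestion. This is a minimal sketch, not part of the module itself.

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Plain lookup: returns a Dialect instance from the global registry.
duckdb = Dialect.get_or_raise("duckdb")

# A settings string: key-value pairs after the dialect name configure the instance.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

# Unknown names raise a ValueError, typically with a "Did you mean ...?" hint.
try:
    Dialect.get_or_raise("duckbd")
except ValueError as e:
    print(e)  # e.g. Unknown dialect 'duckbd'. Did you mean duckdb?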
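To make the `normalize_identifier` docstring concrete, here is a sketch of how the strategies differ; it builds identifiers with `exp.to_identifier` from `sqlglot.expressions` and relies on Postgres defaulting to `LOWERCASE` and Snowflake to `UPPERCASE`, as the docstring states.

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

def ident(quoted: bool) -> exp.Identifier:
    return exp.to_identifier("FoO", quoted=quoted)

# Postgres lowercases unquoted identifiers ...
print(Dialect.get_or_raise("postgres").normalize_identifier(ident(False)).this)   # foo
# ... Snowflake uppercases them ...
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident(False)).this)  # FOO
# ... and quoted identifiers are left untouched under both strategies.
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident(True)).this)   # FoO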
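The `parse`, `generate` and `transpile` methods wire the autofilled `tokenizer_class`, `parser_class` and `generator_class` together. A minimal round trip, assuming the bundled `duckdb` and `hive` dialects:

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")

# parse() tokenizes with this dialect's tokenizer and parses into expression trees.
(tree,) = duckdb.parse("SELECT CAST(x AS TEXT) FROM t")

# generate() renders a tree back to SQL; rendering the same tree with another
# dialect is all that transpilation is.
print(Dialect.get_or_raise("hive").generate(tree))  # e.g. SELECT CAST(x AS STRING) FROM t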
199class Dialect(metaclass=_Dialect): 200 INDEX_OFFSET = 0 201 """The base index offset for arrays.""" 202 203 WEEK_OFFSET = 0 204 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 205 206 UNNEST_COLUMN_ONLY = False 207 """Whether `UNNEST` table aliases are treated as column aliases.""" 208 209 ALIAS_POST_TABLESAMPLE = False 210 """Whether the table alias comes after tablesample.""" 211 212 TABLESAMPLE_SIZE_IS_PERCENT = False 213 """Whether a size in the table sample clause represents percentage.""" 214 215 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 216 """Specifies the strategy according to which identifiers should be normalized.""" 217 218 IDENTIFIERS_CAN_START_WITH_DIGIT = False 219 """Whether an unquoted identifier can start with a digit.""" 220 221 DPIPE_IS_STRING_CONCAT = True 222 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 223 224 STRICT_STRING_CONCAT = False 225 """Whether `CONCAT`'s arguments must be strings.""" 226 227 SUPPORTS_USER_DEFINED_TYPES = True 228 """Whether user-defined data types are supported.""" 229 230 SUPPORTS_SEMI_ANTI_JOIN = True 231 """Whether `SEMI` or `ANTI` joins are supported.""" 232 233 SUPPORTS_COLUMN_JOIN_MARKS = False 234 """Whether the old-style outer join (+) syntax is supported.""" 235 236 NORMALIZE_FUNCTIONS: bool | str = "upper" 237 """ 238 Determines how function names are going to be normalized. 239 Possible values: 240 "upper" or True: Convert names to uppercase. 241 "lower": Convert names to lowercase. 242 False: Disables function name normalization. 243 """ 244 245 LOG_BASE_FIRST: t.Optional[bool] = True 246 """ 247 Whether the base comes first in the `LOG` function. 248 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 249 """ 250 251 NULL_ORDERING = "nulls_are_small" 252 """ 253 Default `NULL` ordering method to use if not explicitly set. 254 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 255 """ 256 257 TYPED_DIVISION = False 258 """ 259 Whether the behavior of `a / b` depends on the types of `a` and `b`. 260 False means `a / b` is always float division. 261 True means `a / b` is integer division if both `a` and `b` are integers. 262 """ 263 264 SAFE_DIVISION = False 265 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 266 267 CONCAT_COALESCE = False 268 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 269 270 HEX_LOWERCASE = False 271 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 272 273 DATE_FORMAT = "'%Y-%m-%d'" 274 DATEINT_FORMAT = "'%Y%m%d'" 275 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 276 277 TIME_MAPPING: t.Dict[str, str] = {} 278 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 279 280 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 281 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 282 FORMAT_MAPPING: t.Dict[str, str] = {} 283 """ 284 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 285 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
286 """ 287 288 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 289 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 290 291 PSEUDOCOLUMNS: t.Set[str] = set() 292 """ 293 Columns that are auto-generated by the engine corresponding to this dialect. 294 For example, such columns may be excluded from `SELECT *` queries. 295 """ 296 297 PREFER_CTE_ALIAS_COLUMN = False 298 """ 299 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 300 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 301 any projection aliases in the subquery. 302 303 For example, 304 WITH y(c) AS ( 305 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 306 ) SELECT c FROM y; 307 308 will be rewritten as 309 310 WITH y(c) AS ( 311 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 312 ) SELECT c FROM y; 313 """ 314 315 # --- Autofilled --- 316 317 tokenizer_class = Tokenizer 318 parser_class = Parser 319 generator_class = Generator 320 321 # A trie of the time_mapping keys 322 TIME_TRIE: t.Dict = {} 323 FORMAT_TRIE: t.Dict = {} 324 325 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 326 INVERSE_TIME_TRIE: t.Dict = {} 327 328 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 329 330 # Delimiters for string literals and identifiers 331 QUOTE_START = "'" 332 QUOTE_END = "'" 333 IDENTIFIER_START = '"' 334 IDENTIFIER_END = '"' 335 336 # Delimiters for bit, hex, byte and unicode literals 337 BIT_START: t.Optional[str] = None 338 BIT_END: t.Optional[str] = None 339 HEX_START: t.Optional[str] = None 340 HEX_END: t.Optional[str] = None 341 BYTE_START: t.Optional[str] = None 342 BYTE_END: t.Optional[str] = None 343 UNICODE_START: t.Optional[str] = None 344 UNICODE_END: t.Optional[str] = None 345 346 # Separator of COPY statement parameters 347 COPY_PARAMS_ARE_CSV = True 348 349 @classmethod 350 def get_or_raise(cls, dialect: DialectType) -> Dialect: 351 """ 352 Look up a dialect in the global dialect registry and return it if it exists. 353 354 Args: 355 dialect: The target dialect. If this is a string, it can be optionally followed by 356 additional key-value pairs that are separated by commas and are used to specify 357 dialect settings, such as whether the dialect's identifiers are case-sensitive. 358 359 Example: 360 >>> dialect = dialect_class = get_or_raise("duckdb") 361 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 362 363 Returns: 364 The corresponding Dialect instance. 365 """ 366 367 if not dialect: 368 return cls() 369 if isinstance(dialect, _Dialect): 370 return dialect() 371 if isinstance(dialect, Dialect): 372 return dialect 373 if isinstance(dialect, str): 374 try: 375 dialect_name, *kv_pairs = dialect.split(",") 376 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 377 except ValueError: 378 raise ValueError( 379 f"Invalid dialect format: '{dialect}'. " 380 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 381 ) 382 383 result = cls.get(dialect_name.strip()) 384 if not result: 385 from difflib import get_close_matches 386 387 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 388 if similar: 389 similar = f" Did you mean {similar}?" 
390 391 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 392 393 return result(**kwargs) 394 395 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 396 397 @classmethod 398 def format_time( 399 cls, expression: t.Optional[str | exp.Expression] 400 ) -> t.Optional[exp.Expression]: 401 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 402 if isinstance(expression, str): 403 return exp.Literal.string( 404 # the time formats are quoted 405 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 406 ) 407 408 if expression and expression.is_string: 409 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 410 411 return expression 412 413 def __init__(self, **kwargs) -> None: 414 normalization_strategy = kwargs.get("normalization_strategy") 415 416 if normalization_strategy is None: 417 self.normalization_strategy = self.NORMALIZATION_STRATEGY 418 else: 419 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 420 421 def __eq__(self, other: t.Any) -> bool: 422 # Does not currently take dialect state into account 423 return type(self) == other 424 425 def __hash__(self) -> int: 426 # Does not currently take dialect state into account 427 return hash(type(self)) 428 429 def normalize_identifier(self, expression: E) -> E: 430 """ 431 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 432 433 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 434 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 435 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 436 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 437 438 There are also dialects like Spark, which are case-insensitive even when quotes are 439 present, and dialects like MySQL, whose resolution rules match those employed by the 440 underlying operating system, for example they may always be case-sensitive in Linux. 441 442 Finally, the normalization behavior of some engines can even be controlled through flags, 443 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 444 445 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 446 that it can analyze queries in the optimizer and successfully capture their semantics. 447 """ 448 if ( 449 isinstance(expression, exp.Identifier) 450 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 451 and ( 452 not expression.quoted 453 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 454 ) 455 ): 456 expression.set( 457 "this", 458 ( 459 expression.this.upper() 460 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 461 else expression.this.lower() 462 ), 463 ) 464 465 return expression 466 467 def case_sensitive(self, text: str) -> bool: 468 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 469 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 470 return False 471 472 unsafe = ( 473 str.islower 474 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 475 else str.isupper 476 ) 477 return any(unsafe(char) for char in text) 478 479 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 480 """Checks if text can be identified given an identify option. 
481 482 Args: 483 text: The text to check. 484 identify: 485 `"always"` or `True`: Always returns `True`. 486 `"safe"`: Only returns `True` if the identifier is case-insensitive. 487 488 Returns: 489 Whether the given text can be identified. 490 """ 491 if identify is True or identify == "always": 492 return True 493 494 if identify == "safe": 495 return not self.case_sensitive(text) 496 497 return False 498 499 def quote_identifier(self, expression: E, identify: bool = True) -> E: 500 """ 501 Adds quotes to a given identifier. 502 503 Args: 504 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 505 identify: If set to `False`, the quotes will only be added if the identifier is deemed 506 "unsafe", with respect to its characters and this dialect's normalization strategy. 507 """ 508 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 509 name = expression.this 510 expression.set( 511 "quoted", 512 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 513 ) 514 515 return expression 516 517 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 518 if isinstance(path, exp.Literal): 519 path_text = path.name 520 if path.is_number: 521 path_text = f"[{path_text}]" 522 523 try: 524 return parse_json_path(path_text) 525 except ParseError as e: 526 logger.warning(f"Invalid JSON path syntax. {str(e)}") 527 528 return path 529 530 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 531 return self.parser(**opts).parse(self.tokenize(sql), sql) 532 533 def parse_into( 534 self, expression_type: exp.IntoType, sql: str, **opts 535 ) -> t.List[t.Optional[exp.Expression]]: 536 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 537 538 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 539 return self.generator(**opts).generate(expression, copy=copy) 540 541 def transpile(self, sql: str, **opts) -> t.List[str]: 542 return [ 543 self.generate(expression, copy=False, **opts) if expression else "" 544 for expression in self.parse(sql) 545 ] 546 547 def tokenize(self, sql: str) -> t.List[Token]: 548 return self.tokenizer.tokenize(sql) 549 550 @property 551 def tokenizer(self) -> Tokenizer: 552 if not hasattr(self, "_tokenizer"): 553 self._tokenizer = self.tokenizer_class(dialect=self) 554 return self._tokenizer 555 556 def parser(self, **opts) -> Parser: 557 return self.parser_class(dialect=self, **opts) 558 559 def generator(self, **opts) -> Generator: 560 return self.generator_class(dialect=self, **opts)
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
WEEK_OFFSET
First day of the week in `DATE_TRUNC(week)`. Defaults to 0 (Monday). -1 would be Sunday.

TABLESAMPLE_SIZE_IS_PERCENT
Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY
Specifies the strategy according to which identifiers should be normalized.

NORMALIZE_FUNCTIONS
Determines how function names are going to be normalized. Possible values:
- "upper" or True: Convert names to uppercase.
- "lower": Convert names to lowercase.
- False: Disables function name normalization.

LOG_BASE_FIRST
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).

NULL_ORDERING
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.

TYPED_DIVISION
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.

CONCAT_COALESCE
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.

TIME_MAPPING
Associates this dialect's time formats with their equivalent Python `strftime` formats.

FORMAT_MAPPING
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.

UNESCAPED_SEQUENCES
Mapping of an escaped sequence (`\n`, i.e. backslash followed by "n") to its unescaped version (the newline character).

PSEUDOCOLUMNS
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.

PREFER_CTE_ALIAS_COLUMN
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
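To illustrate how these class-level settings are consumed, here is a hypothetical custom dialect that overrides a few of them; the values are purely illustrative, and the class is registered automatically under the key "mydialect" by the _Dialect metaclass:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

class MyDialect(Dialect):
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE  # unquoted identifiers -> uppercase
    NULL_ORDERING = "nulls_are_large"
    TYPED_DIVISION = True
    TIME_MAPPING = {"yyyy": "%Y", "mm": "%m", "dd": "%d"}  # dialect tokens -> Python strftime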
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
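A small sketch of the behavior described above, using two registered dialects:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

quoted = exp.to_identifier("FoO", quoted=True)  # quoted identifiers are left untouched
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).this)  # FoO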
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
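For instance, a sketch against the base dialect (which quotes with double quotes):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())     # foo
print(d.quote_identifier(exp.to_identifier("my col"), identify=False).sql())  # "my col"
print(d.quote_identifier(exp.to_identifier("foo"), identify=True).sql())      # "foo"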
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path
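A brief sketch: string literals are parsed into a structured exp.JSONPath, while plain number literals are first wrapped as subscripts (e.g. 0 becomes "[0]"):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

path = Dialect().to_json_path(exp.Literal.string("$.a[0].b"))
print(type(path).__name__, path.sql())  # JSONPath $.a[0].b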
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
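This is a generator factory: calling it returns a function suitable for a Generator's TRANSFORMS table. A hedged sketch (the IIF name and NULL fallback are illustrative, not tied to any particular dialect):

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import if_sql

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Render exp.If as IIF(cond, true, false), defaulting a missing false branch to NULL
        exp.If: if_sql(name="IIF", false_value="NULL"),
    }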
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
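For illustration, a parser might register this builder under a function name so that the second argument is converted through the named dialect's TIME_MAPPING (the STR_TO_DATE/mysql pairing here is an example, not a prescription):

from sqlglot import exp
from sqlglot.parser import Parser
from sqlglot.dialects.dialect import build_formatted_time

class MyParser(Parser):
    FUNCTIONS = {
        **Parser.FUNCTIONS,
        # STR_TO_DATE(x, '%d-%m-%Y') -> exp.StrToTime with a normalized format literal
        "STR_TO_DATE": build_formatted_time(exp.StrToTime, "mysql"),
    }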
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
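A direct-invocation sketch of the two argument shapes this builder accepts; the "dd" unit key is illustrative:

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY"})

# Unit-based shape, as in DATEADD(dd, 5, col): [unit, amount, target]
node = builder([exp.var("dd"), exp.Literal.number(5), exp.column("col")])
assert node.text("unit") == "DAY" and node.this.name == "col"

# Two-argument shape: the unit defaults to DAY
node = builder([exp.column("col"), exp.Literal.number(5)])
assert node.text("unit") == "DAY"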
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
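A quick sketch of the rewrite this helper performs, with a generator obtained from the base dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, count_if_to_sum

g = Dialect().generator()
e = exp.CountIf(this=exp.condition("x > 0"))
print(count_if_to_sum(g, e))  # e.g. SUM(IF(x > 0, 1, 0))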
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
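As with the other factories in this module, this is typically wired into a generator's TRANSFORMS; a hedged sketch (the DATEADD/DATEDIFF names are illustrative):

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import date_delta_sql

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # exp.DateAdd(this=d, expression=5, unit=DAY) -> DATEADD(DAY, 5, d)
        exp.DateAdd: date_delta_sql("DATEADD"),
        exp.DateDiff: date_delta_sql("DATEDIFF"),
    }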
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
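A direct-invocation sketch, mimicking a parsed call like JSON_EXTRACT_PATH(col, 'a', '0'):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("col"), exp.Literal.string("a"), exp.Literal.string("0")])
print(node.expression.sql())  # $.a[0]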
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
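For example, a sketch: a bare DECIMAL picks up the defaults, while an explicit parameter list is left alone:

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

build_decimal = build_default_decimal_type(precision=38, scale=9)
print(build_decimal(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
print(build_decimal(exp.DataType.build("DECIMAL(10, 2)")).sql())  # DECIMAL(10, 2)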
def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)