Fixed an issue that prevented the password reset tokens from working. Added email templates for password reset success and new account creation. Added more dynamic email template support.
1542 lines
49 KiB
Python
1542 lines
49 KiB
Python
# Copyright (c) 2016, 2022, Oracle and/or its affiliates.
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License, version 2.0, as
|
|
# published by the Free Software Foundation.
|
|
#
|
|
# This program is also distributed with certain software (including
|
|
# but not limited to OpenSSL) that is licensed under separate terms,
|
|
# as designated in a particular file or component or in included license
|
|
# documentation. The authors of MySQL hereby grant you an
|
|
# additional permission to link the program and your derivative works
|
|
# with the separately licensed software that they have included with
|
|
# MySQL.
|
|
#
|
|
# Without limiting anything contained in the foregoing, this file,
|
|
# which is part of MySQL Connector/Python, is also subject to the
|
|
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
|
# http://oss.oracle.com/licenses/universal-foss-exception.
|
|
#
|
|
# This program is distributed in the hope that it will be useful, but
|
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
# See the GNU General Public License, version 2.0, for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
|
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
|
|
|
# mypy: disable-error-code="return-value"
|
|
|
|
"""Implementation of Statements."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import json
|
|
import warnings
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
from .constants import LockContention
|
|
from .dbdoc import DbDoc
|
|
from .errors import NotSupportedError, ProgrammingError
|
|
from .expr import ExprParser
|
|
from .helpers import deprecated
|
|
from .protobuf import mysqlxpb_enum
|
|
from .result import DocResult, Result, RowResult, SqlResult
|
|
from .types import (
|
|
ConnectionType,
|
|
DatabaseTargetType,
|
|
MessageType,
|
|
ProtobufMessageCextType,
|
|
ProtobufMessageType,
|
|
SchemaType,
|
|
)
|
|
|
|
# Message template for an invalid index name; fill the placeholder with the
# rejected name via str.format().
ERR_INVALID_INDEX_NAME = 'The given index name "{}" is not valid'
|
|
|
|
|
|
class Expr:
    """Thin wrapper that tags a value as an expression.

    Args:
        expr: The wrapped expression object; it is stored unchanged on the
              ``expr`` attribute.
    """

    def __init__(self, expr: Any) -> None:
        wrapped: Any = expr
        self.expr: Any = wrapped
|
|
|
|
|
|
def flexible_params(*values: Any) -> Union[List, Tuple]:
    """Normalize variadic arguments that may also arrive as one sequence.

    Lets callers write either ``f(a, b)`` or ``f([a, b])``. When exactly one
    positional argument is given and it is a list or tuple, that very
    sequence is returned; otherwise the tuple of all arguments is returned.
    """
    single_sequence = len(values) == 1 and isinstance(values[0], (list, tuple))
    return values[0] if single_sequence else values
|
|
|
|
|
|
def is_quoted_identifier(identifier: str, sql_mode: str = "") -> bool:
    """Check if the given identifier is quoted.

    An identifier counts as quoted when it has at least two characters and
    starts and ends with the same quote character: a backtick, or, when
    ``sql_mode`` contains ``ANSI_QUOTES``, also a double quote.

    Args:
        identifier (string): Identifier to check.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        `True` if the identifier is enclosed in quotes, and `False` otherwise
        (including for an empty string or a lone quote character).
    """
    # Fewer than two characters cannot hold both an opening and a closing
    # quote; this also avoids indexing into an empty string.
    if len(identifier) < 2:
        return False
    quotes = ("`", '"') if "ANSI_QUOTES" in sql_mode else ("`",)
    first, last = identifier[0], identifier[-1]
    return first in quotes and first == last
|
|
|
|
|
|
def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Wrap an identifier in quote characters, escaping embedded quotes.

    Backticks are used by default; with ``ANSI_QUOTES`` in ``sql_mode`` a
    double quote is used instead. Any occurrence of the chosen quote inside
    the name is doubled (e.g. ``a`b`` becomes ```a``b```). An empty
    identifier always yields ``````.

    Args:
        identifier (string): Identifier to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        The quoted identifier string.
    """
    if len(identifier) == 0:
        return "``"
    quote_char = '"' if "ANSI_QUOTES" in sql_mode else "`"
    escaped = identifier.replace(quote_char, quote_char * 2)
    return f"{quote_char}{escaped}{quote_char}"
|
|
|
|
|
|
def quote_multipart_identifier(identifiers: Iterable[str], sql_mode: str = "") -> str:
    """Quote every part of a dotted identifier and join the parts with ``.``.

    Each part is quoted independently with :func:`quote_identifier`, so
    ``["db", "tbl"]`` becomes ```db`.`tbl```.

    Args:
        identifiers (iterable): The identifier parts, outermost first.
        sql_mode (Optional[string]): SQL mode, forwarded to the per-part
                                     quoting.

    Returns:
        The dotted, quoted identifier string.
    """
    quoted_parts = (quote_identifier(part, sql_mode) for part in identifiers)
    return ".".join(quoted_parts)
|
|
|
|
|
|
def parse_table_name(
    default_schema: str, table_name: str, sql_mode: str = ""
) -> Tuple[str, str]:
    """Split a possibly schema-qualified table name into its two parts.

    The name is split at the first dot that is not inside quotes, so quoted
    parts may themselves contain dots (``"`my.db`.`t`"``), and quoted and
    unquoted parts may be mixed (``"`db`.t"``). The quote character is a
    backtick, or a double quote when ``sql_mode`` contains ``ANSI_QUOTES``.

    Args:
        default_schema (str): Schema name to use when ``table_name`` is not
                              schema-qualified.
        table_name (str): The table name, optionally prefixed by a schema.
        sql_mode (Optional[str]): The SQL mode.

    Returns:
        Tuple[str, str]: ``(schema_name, table_name)`` with surrounding quote
        characters stripped from each part.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    in_quotes = False
    for pos, char in enumerate(table_name):
        if char == quote:
            # Toggling on every quote also handles escaped (doubled) quotes,
            # since a doubled quote closes and immediately reopens the span.
            in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            return (
                table_name[:pos].strip(quote),
                table_name[pos + 1 :].strip(quote),
            )
    return default_schema, table_name.strip(quote)
|
|
|
|
|
|
class Statement:
    """Common state and bookkeeping shared by every statement object.

    Keeps the target object, whether the statement is document based, and
    the counters/flags consulted when deciding how the statement is
    (re)prepared and executed.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool = True) -> None:
        self._target: DatabaseTargetType = target
        self._doc_based: bool = doc_based
        # A falsy target (e.g. None) leaves the statement without a connection.
        self._connection: Optional[ConnectionType] = None
        if target:
            self._connection = target.get_connection()
        self._stmt_id: Optional[int] = None
        self._exec_counter: int = 0
        self._changed: bool = True
        self._prepared: bool = False
        self._deallocate_prepare_execute: bool = False

    @property
    def target(self) -> DatabaseTargetType:
        """object: The object this statement operates on."""
        return self._target

    @property
    def schema(self) -> SchemaType:
        """:class:`mysqlx.Schema`: The schema of the target object."""
        return self._target.schema

    @property
    def stmt_id(self) -> int:
        """int: The statement ID (``None`` until one is assigned)."""
        return self._stmt_id

    @stmt_id.setter
    def stmt_id(self, new_id: int) -> None:
        self._stmt_id = new_id

    @property
    def exec_counter(self) -> int:
        """int: How many times this statement has been executed."""
        return self._exec_counter

    @property
    def changed(self) -> bool:
        """bool: Whether the statement was modified since it was last sent."""
        return self._changed

    @changed.setter
    def changed(self, flag: bool) -> None:
        self._changed = flag

    @property
    def prepared(self) -> bool:
        """bool: Whether the statement has been prepared."""
        return self._prepared

    @prepared.setter
    def prepared(self, flag: bool) -> None:
        self._prepared = flag

    @property
    def repeated(self) -> bool:
        """bool: Whether the statement has been executed two or more times."""
        return self._exec_counter >= 2

    @property
    def deallocate_prepare_execute(self) -> bool:
        """bool: Whether to run the deallocate + prepare + execute sequence."""
        return self._deallocate_prepare_execute

    @deallocate_prepare_execute.setter
    def deallocate_prepare_execute(self, flag: bool) -> None:
        self._deallocate_prepare_execute = flag

    def is_doc_based(self) -> bool:
        """Tell whether the statement works on documents.

        Returns:
            bool: `True` if it is document based.
        """
        return self._doc_based

    def increment_exec_counter(self) -> None:
        """Record one more execution of this statement."""
        self._exec_counter = self._exec_counter + 1

    def reset_exec_counter(self) -> None:
        """Set the execution count back to zero."""
        self._exec_counter = 0

    def execute(self) -> Any:
        """Execute the statement; concrete subclasses override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class FilterableStatement(Statement):
    """A statement to be used with filterable statements.

    Collects the clause state shared by statements that filter documents or
    records: where condition, projection, sort, group by, having,
    limit/offset and placeholder bindings. Each ``has_*`` attribute records
    whether the corresponding clause has been set.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target=target, doc_based=doc_based)
        self._binding_map: Dict[str, Any] = {}
        self._bindings: Union[Dict[str, Any], List] = {}
        self._having: Optional[MessageType] = None
        self._grouping_str: str = ""
        self._grouping: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._limit_offset: int = 0
        self._limit_row_count: Optional[int] = None
        self._projection_str: str = ""
        self._projection_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._sort_str: str = ""
        self._sort_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._where_str: str = ""
        self._where_expr: Optional[MessageType] = None
        self.has_bindings: bool = False
        self.has_limit: bool = False
        self.has_group_by: bool = False
        self.has_having: bool = False
        self.has_projection: bool = False
        self.has_sort: bool = False
        self.has_where: bool = False
        if condition:
            self._set_where(condition)

    def _bind_single(self, obj: Union[DbDoc, Dict[str, Any], str]) -> None:
        """Bind every key/value pair of a single object.

        Dicts and :class:`mysqlx.DbDoc` objects are serialized to a JSON
        string and re-enter :meth:`bind`; a JSON string is decoded and each
        top-level key is bound to its value.

        Args:
            obj (:class:`mysqlx.DbDoc`, dict or str): DbDoc, dict or JSON
                                                      string object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the JSON string is invalid,
                does not decode to an object, or ``obj`` has another type.
        """
        if isinstance(obj, dict):
            self.bind(DbDoc(obj).as_str())
        elif isinstance(obj, DbDoc):
            self.bind(obj.as_str())
        elif isinstance(obj, str):
            try:
                res = json.loads(obj)
                if not isinstance(res, dict):
                    # Only JSON objects can map placeholder names to values.
                    raise ValueError
            except ValueError as err:
                raise ProgrammingError("Invalid JSON string to bind") from err
            for key in res.keys():
                self.bind(key, res[key])
        else:
            raise ProgrammingError("Invalid JSON string or object to bind")

    def _sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        self.has_sort = True
        self._sort_str = ",".join(flexible_params(*clauses))
        self._sort_expr = ExprParser(
            self._sort_str, not self._doc_based
        ).parse_order_spec()
        self._changed = True
        return self

    def _set_where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the condition cannot be parsed.
        """
        self.has_where = True
        self._where_str = condition
        try:
            expr = ExprParser(condition, not self._doc_based)
            self._where_expr = expr.expr()
        except ValueError as err:
            raise ProgrammingError("Invalid condition") from err
        # Keep the parser's placeholder_name_to_position mapping for the
        # named placeholders found in the condition.
        self._binding_map = expr.placeholder_name_to_position
        self._changed = True
        return self

    def _set_group_by(self, *fields: str) -> None:
        """Set group by.

        Args:
            *fields: List of fields.
        """
        fields = flexible_params(*fields)
        self.has_group_by = True
        self._grouping_str = ",".join(fields)
        self._grouping = ExprParser(
            self._grouping_str, not self._doc_based
        ).parse_expr_list()
        self._changed = True

    def _set_having(self, condition: str) -> None:
        """Set having.

        Args:
            condition (str): The condition.
        """
        self.has_having = True
        self._having = ExprParser(condition, not self._doc_based).expr()
        self._changed = True

    def _set_projection(self, *fields: str) -> FilterableStatement:
        """Set the projection.

        Args:
            *fields: List of fields.

        Returns:
            :class:`mysqlx.FilterableStatement`: Returns self.
        """
        fields = flexible_params(*fields)
        self.has_projection = True
        self._projection_str = ",".join(fields)
        self._projection_expr = ExprParser(
            self._projection_str, not self._doc_based
        ).parse_table_select_projection()
        self._changed = True
        return self

    def get_binding_map(self) -> Dict[str, Any]:
        """Returns the binding map dictionary.

        Returns:
            dict: The binding map dictionary.
        """
        return self._binding_map

    def get_bindings(self) -> Union[Dict[str, Any], List]:
        """Returns the bindings.

        Returns:
            dict: The placeholder name to value bindings.
        """
        return self._bindings

    def get_grouping(
        self,
    ) -> Optional[List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
        """Returns the grouping expression list.

        Returns:
            `list`: The grouping expression list (``None`` if not set).
        """
        return self._grouping

    def get_having(self) -> Optional[MessageType]:
        """Returns the having expression.

        Returns:
            object: The having expression (``None`` if not set).
        """
        return self._having

    def get_limit_row_count(self) -> Optional[int]:
        """Returns the limit row count.

        Returns:
            int: The limit row count (``None`` if not set).
        """
        return self._limit_row_count

    def get_limit_offset(self) -> int:
        """Returns the limit offset.

        Returns:
            int: The limit offset.
        """
        return self._limit_offset

    def get_where_expr(self) -> Optional[MessageType]:
        """Returns the where expression.

        Returns:
            object: The where expression (``None`` if not set).
        """
        return self._where_expr

    def get_projection_expr(
        self,
    ) -> Optional[List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
        """Returns the projection expression.

        Returns:
            object: The projection expression (``None`` if not set).
        """
        return self._projection_expr

    def get_sort_expr(
        self,
    ) -> Optional[List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
        """Returns the sort expression.

        Returns:
            object: The sort expression (``None`` if not set).
        """
        return self._sort_expr

    @deprecated("8.0.12")
    def where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._set_where(condition)

    @deprecated("8.0.12")
    def sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._sort(*clauses)

    def limit(
        self, row_count: int, offset: Optional[int] = None
    ) -> FilterableStatement:
        """Sets the maximum number of items to be returned.

        Args:
            row_count (int): The maximum number of items.
            offset (Optional[int]): Deprecated. Number of items to skip; use
                                    :meth:`offset` instead. An explicit ``0``
                                    is applied like any other value.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``row_count`` (or ``offset``) is not a
                        non-negative integer.

        .. versionchanged:: 8.0.12
           The usage of ``offset`` was deprecated.
        """
        if not isinstance(row_count, int) or row_count < 0:
            raise ValueError("The 'row_count' value must be a positive integer")
        if not self.has_limit:
            # Only the first limit() call touches these flags: before any
            # execution the statement is simply marked changed; once it has
            # been executed, it is flagged for deallocate + prepare + execute.
            self._changed = bool(self._exec_counter == 0)
            self._deallocate_prepare_execute = bool(not self._exec_counter == 0)

        self._limit_row_count = row_count
        self.has_limit = True
        # ``is not None`` so an explicit offset of 0 is honoured; a plain
        # truthiness test would ignore it and keep any earlier offset.
        if offset is not None:
            self.offset(offset)
            warnings.warn(
                "'limit(row_count, offset)' is deprecated, please "
                "use 'offset(offset)' to set the number of items to "
                "skip",
                category=DeprecationWarning,
            )
        return self

    def offset(self, offset: int) -> FilterableStatement:
        """Sets the number of items to skip.

        Args:
            offset (int): The number of items to skip.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``offset`` is not a non-negative integer.

        .. versionadded:: 8.0.12
        """
        if not isinstance(offset, int) or offset < 0:
            raise ValueError("The 'offset' value must be a positive integer")
        self._limit_offset = offset
        return self

    def bind(self, *args: Any) -> FilterableStatement:
        """Binds value(s) to a specific placeholder(s).

        Args:
            *args: The name of the placeholder and the value to bind.
                   A :class:`mysqlx.DbDoc` object, a dict or a JSON string
                   representation can be used instead to bind several
                   placeholders at once.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the number of arguments is invalid.
        """
        self.has_bindings = True
        count = len(args)
        if count == 1:
            self._bind_single(args[0])
        elif count == 2:
            self._bindings[args[0]] = args[1]
        else:
            raise ProgrammingError("Invalid number of arguments to bind")
        return self

    def execute(self) -> Any:
        """Execute the statement.

        Raises:
            NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError
|
|
|
|
|
|
class SqlStatement(Statement):
    """A raw SQL statement executed over an X Protocol connection.

    Args:
        connection (mysqlx.connection.Connection): Connection object.
        sql (string): The sql statement to be executed.
    """

    def __init__(self, connection: ConnectionType, sql: str) -> None:
        super().__init__(target=None, doc_based=False)
        # There is no target object to obtain a connection from, so the
        # connection is supplied directly.
        self._connection: ConnectionType = connection
        self._sql: str = sql
        self._binding_map: Optional[Dict[str, Any]] = None
        self._bindings: Union[List, Tuple] = []
        self.has_bindings: bool = False
        self.has_limit: bool = False

    @property
    def sql(self) -> str:
        """string: The SQL text of this statement."""
        return self._sql

    def get_binding_map(self) -> Dict[str, Any]:
        """Return the binding map (always ``None`` for SQL statements).

        Returns:
            dict: The binding map dictionary.
        """
        return self._binding_map

    def get_bindings(self) -> Union[Tuple, List]:
        """Return the positional values bound to the statement.

        Returns:
            `list`: The bound values.
        """
        return self._bindings

    def bind(self, *args: Any) -> SqlStatement:
        """Bind positional placeholder values.

        Values may be given as separate arguments or as one list/tuple.

        Args:
            *args: The value(s) to bind.

        Returns:
            mysqlx.SqlStatement: SqlStatement object.

        Raises:
            ProgrammingError: If no value is given.
        """
        if not args:
            raise ProgrammingError("Invalid number of arguments to bind")
        self.has_bindings = True
        # flexible_params always yields a list or tuple, so each call
        # replaces any previously bound values.
        self._bindings = flexible_params(*args)
        return self

    def execute(self) -> SqlResult:
        """Send the SQL text to the server.

        Returns:
            mysqlx.SqlResult: SqlResult object.
        """
        return self._connection.send_sql(self)
|
|
|
|
|
|
class WriteStatement(Statement):
    """Base class for statements that queue values to be written.

    Args:
        target (object): The target database object.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool) -> None:
        super().__init__(target, doc_based)
        # Values waiting to be written; subclasses append to this list.
        self._values: List[
            Union[
                int,
                str,
                DbDoc,
                Dict[str, Any],
                List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
            ]
        ] = []

    def get_values(
        self,
    ) -> List[
        Union[
            int,
            str,
            DbDoc,
            Dict[str, Any],
            List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
        ]
    ]:
        """Give access to the queued values.

        Returns:
            `list`: The internal list of values (not a copy).
        """
        return self._values

    def execute(self) -> Any:
        """Execute the statement; concrete subclasses override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class AddStatement(WriteStatement):
    """Statement that inserts documents into a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """

    def __init__(self, collection: DatabaseTargetType) -> None:
        super().__init__(collection, True)
        self._upsert: bool = False
        self.ids: List = []

    def is_upsert(self) -> bool:
        """Tell whether the upsert flag is set.

        Returns:
            bool: `True` if it's an upsert.
        """
        return self._upsert

    def upsert(self, value: bool = True) -> AddStatement:
        """Set the upsert flag.

        With the flag set, matching rows/documents are updated with the
        provided value.

        Args:
            value (optional[bool]): Set or unset the upsert flag.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        self._upsert = value
        return self

    def add(self, *values: DbDoc) -> AddStatement:
        """Queue documents for insertion into the collection.

        Anything that is not already a :class:`mysqlx.DbDoc` is wrapped in
        one. Documents can be passed individually or as one list/tuple.

        Args:
            *values: The documents to be added into the collection.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        for doc in flexible_params(*values):
            self._values.append(doc if isinstance(doc, DbDoc) else DbDoc(doc))
        return self

    def execute(self) -> Result:
        """Send the insert; an empty statement returns an empty result.

        Returns:
            mysqlx.Result: Result object.
        """
        if not self._values:
            return Result()
        return self._connection.send_insert(self)
|
|
|
|
|
|
class UpdateSpec:
    """One update operation: its type, the parsed source and the new value.

    ``SET`` operations (table updates) parse ``source`` as a table update
    field; every other type parses it as a document field path.

    Args:
        update_type (int): The update type.
        source (str): The source.
        value (Optional[str]): The value.

    Raises:
        ProgrammingError: If `source` is not a valid document field (non-SET
                          update types).
    """

    def __init__(self, update_type: int, source: str, value: Any = None) -> None:
        if update_type != mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"):
            self.update_type: int = update_type
            try:
                field = ExprParser(source, False).document_field()
                self.source: Any = field.identifier
            except ValueError as err:
                raise ProgrammingError(f"{err}") from err
            self.value: Any = value
        else:
            self._table_set(source, value)

    def _table_set(self, source: str, value: Any) -> None:
        """Fill in the attributes for a table ``SET`` operation.

        Args:
            source (str): The source.
            value (str): The value.
        """
        self.update_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        parser = ExprParser(source, True)
        self.source = parser.parse_table_update_field()
        self.value = value
|
|
|
|
|
|
class ModifyStatement(FilterableStatement):
    """Statement that updates documents in a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be modified.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter is now mandatory.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        super().__init__(target=collection, condition=condition)
        # Pending operations keyed by document path ("patch" for merge
        # patches); each value is an UpdateSpec.
        self._update_ops: Dict[str, Any] = {}

    def _queue_update_op(
        self, key: str, update_type_name: str, *spec_args: Any
    ) -> ModifyStatement:
        """Store an UpdateSpec under ``key`` and mark the statement changed."""
        self._update_ops[key] = UpdateSpec(mysqlxpb_enum(update_type_name), *spec_args)
        self._changed = True
        return self

    def sort(self, *clauses: str) -> ModifyStatement:
        """Set the order in which matching documents are modified.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self) -> Dict[str, Any]:
        """Give access to the pending update operations.

        Returns:
            dict: Update operations keyed by document path.
        """
        return self._update_ops

    def set(self, doc_path: str, value: Any) -> ModifyStatement:
        """Set or update an attribute on the matching documents.

        Args:
            doc_path (string): The document path of the item to be set.
            value (string): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._queue_update_op(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_SET",
            doc_path,
            value,
        )

    @deprecated("8.0.12")
    def change(self, doc_path: str, value: Any) -> ModifyStatement:
        """Replace the value at ``doc_path``, if that field exists.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. deprecated:: 8.0.12
        """
        return self._queue_update_op(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REPLACE",
            doc_path,
            value,
        )

    def unset(self, *doc_paths: str) -> ModifyStatement:
        """Remove attributes from the matching documents.

        Args:
            doc_paths (list): The document paths of the attributes to be
                              removed.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        remove_type = "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REMOVE"
        for path in flexible_params(*doc_paths):
            self._update_ops[path] = UpdateSpec(mysqlxpb_enum(remove_type), path)
        # Flagged even when no paths were given.
        self._changed = True
        return self

    def array_insert(self, field: str, value: Any) -> ModifyStatement:
        """Insert a value into an array at the position given by ``field``.

        Args:
            field (string): A document path that identifies the array attribute
                            and position where the value will be inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._queue_update_op(
            field,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_INSERT",
            field,
            value,
        )

    def array_append(self, doc_path: str, value: Any) -> ModifyStatement:
        """Append a value to the array attribute at ``doc_path``.

        Args:
            doc_path (string): A document path that identifies the array
                               attribute and position where the value will be
                               inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._queue_update_op(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_APPEND",
            doc_path,
            value,
        )

    def patch(self, doc: Union[Dict, DbDoc, ExprParser, str]) -> ModifyStatement:
        """Merge-patch all matching documents with the given changes.

        ``None`` is treated as an empty string. An :class:`ExprParser` is
        converted to its parsed expression before being stored.

        Args:
            doc (object): A generic document (DbDoc), string in JSON format or
                          dict, with the changes to apply to the matching
                          documents.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        Raises:
            ProgrammingError: If ``doc`` has an unsupported type.
        """
        if doc is None:
            doc = ""
        accepted_types = (ExprParser, dict, DbDoc, str)
        if not isinstance(doc, accepted_types):
            raise ProgrammingError(
                "Invalid data for update operation on document collection table"
            )
        self._update_ops["patch"] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.MERGE_PATCH"),
            "$",
            doc.expr() if isinstance(doc, ExprParser) else doc,
        )
        self._changed = True
        return self

    def execute(self) -> Result:
        """Send the update to the server.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If no condition was set.
        """
        if not self.has_where:
            raise ProgrammingError("No condition was found for modify")
        return self._connection.send_update(self)
|
|
|
|
|
|
class ReadStatement(FilterableStatement):
    """Provide base functionality for Read operations.

    Adds row-locking options (shared/exclusive with a lock contention mode)
    and group by / having clauses on top of :class:`FilterableStatement`.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target, doc_based, condition)
        self._lock_exclusive: bool = False
        self._lock_shared: bool = False
        self._lock_contention: LockContention = LockContention.DEFAULT

    @property
    def lock_contention(self) -> LockContention:
        """:class:`mysqlx.LockContention`: The lock contention value."""
        return self._lock_contention

    def _set_lock_contention(self, lock_contention: LockContention) -> None:
        """Validate and store the lock contention.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Raises:
            ProgrammingError: If ``lock_contention`` is not a valid
                              :class:`mysqlx.LockContention` value.
        """
        try:
            # Round-trip through the enum constructor: unknown values raise
            # ValueError, and objects without a ``value`` attribute (plain
            # strings, ints, ...) raise AttributeError.
            validated = LockContention(lock_contention.value)
        except (AttributeError, ValueError) as err:
            raise ProgrammingError(
                "Invalid lock contention mode. Use 'NOWAIT' or 'SKIP_LOCKED'"
            ) from err
        self._lock_contention = validated

    def is_lock_exclusive(self) -> bool:
        """Returns `True` if is `EXCLUSIVE LOCK`.

        Returns:
            bool: `True` if is `EXCLUSIVE LOCK`.
        """
        return self._lock_exclusive

    def is_lock_shared(self) -> bool:
        """Returns `True` if is `SHARED LOCK`.

        Returns:
            bool: `True` if is `SHARED LOCK`.
        """
        return self._lock_shared

    def lock_shared(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `SHARED LOCK`. Only one lock can be
           active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid; the lock
                              flags are left unchanged in that case.
        """
        # Validate first so a rejected value leaves the lock flags untouched.
        self._set_lock_contention(lock_contention)
        self._lock_exclusive = False
        self._lock_shared = True
        return self

    def lock_exclusive(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `EXCLUSIVE LOCK`. Only one lock can be
           active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid; the lock
                              flags are left unchanged in that case.
        """
        # Validate first so a rejected value leaves the lock flags untouched.
        self._set_lock_contention(lock_contention)
        self._lock_exclusive = True
        self._lock_shared = False
        return self

    def group_by(self, *fields: str) -> ReadStatement:
        """Sets a grouping criteria for the resultset.

        Args:
            *fields: The string expressions identifying the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_group_by(*fields)
        return self

    def having(self, condition: str) -> ReadStatement:
        """Sets a condition for records to be considered in aggregate function
           operations.

        Args:
            condition (string): A condition on the aggregate functions used on
                                the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_having(condition)
        return self

    def execute(self) -> Union[DocResult, RowResult]:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        return self._connection.send_find(self)
|
|
|
|
|
|
class FindStatement(ReadStatement):
    """Statement that selects documents from a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): Optional expression selecting the documents
                                   to return; without it every document is
                                   included unless a limit is set.
    """

    def __init__(
        self, collection: DatabaseTargetType, condition: Optional[str] = None
    ) -> None:
        super().__init__(target=collection, doc_based=True, condition=condition)

    def fields(self, *fields: str) -> FindStatement:
        """Restrict the returned documents to the given field expressions.

        Args:
            *fields: The string expressions identifying the fields to be
                     extracted.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        projected = self._set_projection(*fields)
        return projected

    def sort(self, *clauses: str) -> FindStatement:
        """Set the order of the returned documents.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        ordered = self._sort(*clauses)
        return ordered
|
|
|
|
|
|
class SelectStatement(ReadStatement):
    """A statement that retrieves records from a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be retrieved.
    """

    def __init__(self, table: DatabaseTargetType, *fields: str) -> None:
        super().__init__(table, False)
        self._set_projection(*fields)

    def where(self, condition: str) -> SelectStatement:
        """Filter the records by a search condition.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        filtered = self._set_where(condition)
        return filtered

    def order_by(self, *clauses: str) -> SelectStatement:
        """Define the ordering of the selected records.

        Args:
            *clauses: Expression strings describing the order by criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        ordered = self._sort(*clauses)
        return ordered

    def get_sql(self) -> str:
        """Render this statement as SQL text.

        Optional clauses are appended in the order WHERE, GROUP BY, HAVING,
        ORDER BY and LIMIT/OFFSET, each only when it has been set.

        Returns:
            str: The generated SQL.
        """
        clauses = []
        if self.has_where:
            clauses.append(f" WHERE {self._where_str}")
        if self.has_group_by:
            clauses.append(f" GROUP BY {self._grouping_str}")
        if self.has_having:
            clauses.append(f" HAVING {self._having}")
        if self.has_sort:
            clauses.append(f" ORDER BY {self._sort_str}")
        if self.has_limit:
            clauses.append(
                f" LIMIT {self._limit_row_count} OFFSET {self._limit_offset}"
            )
        projection = self._projection_str or "*"
        head = f"SELECT {projection} FROM {self.schema.name}.{self.target.name}"
        return head + "".join(clauses)
|
|
|
|
|
|
class InsertStatement(WriteStatement):
    """A statement that inserts rows into a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be inserted.
    """

    def __init__(self, table: DatabaseTargetType, *fields: Any) -> None:
        super().__init__(table, False)
        self._fields: Union[List, Tuple] = flexible_params(*fields)

    def values(self, *values: Any) -> InsertStatement:
        """Queue one row of values for insertion.

        Args:
            *values: The values of the columns to be inserted.

        Returns:
            mysqlx.InsertStatement: InsertStatement object.
        """
        row = list(flexible_params(*values))
        self._values.append(row)
        return self

    def execute(self) -> Result:
        """Send the insert to the server through the session connection.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.send_insert(self)
|
|
|
|
|
|
class UpdateStatement(FilterableStatement):
    """A statement for record update operations on a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``fields`` parameters were removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)
        # Column name -> UpdateSpec; calling set() again for the same column
        # replaces the earlier operation.
        self._update_ops: Dict[str, Any] = {}

    def where(self, condition: str) -> UpdateStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> UpdateStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self) -> Dict[str, Any]:
        """Returns the update operations registered so far.

        Returns:
            `dict`: A mapping from column name to the update operation
            (``UpdateSpec``) registered for that column by :meth:`set`.
        """
        return self._update_ops

    def set(self, field: str, value: Any) -> UpdateStatement:
        """Updates the column value on records in a table.

        Args:
            field (string): The column name to be updated.
            value (object): The value to be set on the specified column.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        self._update_ops[field] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"),
            field,
            value,
        )
        # Mark the statement as modified since it was last built.
        self._changed = True
        return self

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object

        Raises:
            ProgrammingError: If no condition was set with :meth:`where`.
        """
        # Refuse unconditional updates of the whole table.
        if not self.has_where:
            raise ProgrammingError("No condition was found for update")
        return self._connection.send_update(self)
|
|
|
|
|
|
class RemoveStatement(FilterableStatement):
    """A statement for document removal from a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be removed.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was added.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        super().__init__(target=collection, condition=condition)

    def sort(self, *clauses: str) -> RemoveStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.
        """
        return self._sort(*clauses)

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        # Refuse to remove every document when no condition is present.
        if not self.has_where:
            raise ProgrammingError("No condition was found for remove")
        return self._connection.send_delete(self)
|
|
|
|
|
|
class DeleteStatement(FilterableStatement):
    """A statement for record deletion operations on a Table.

    Only the records matching the condition set with :meth:`where` are
    deleted; the table itself is left in place.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)

    def where(self, condition: str) -> DeleteStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> DeleteStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        return self._sort(*clauses)

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        # Refuse unconditional deletes of every record in the table.
        if not self.has_where:
            raise ProgrammingError("No condition was found for delete")
        return self._connection.send_delete(self)
|
|
|
|
|
|
class CreateCollectionIndexStatement(Statement):
    """A statement that creates an index on a collection.

    Args:
        collection (mysqlx.Collection): Collection.
        index_name (string): Index name.
        index_desc (dict): A dictionary containing the fields members that
                           constrain the index to be created. It must have
                           the form as shown in the following::

                               {"fields": [{"field": member_path,
                                            "type": member_type,
                                            "required": member_required,
                                            "collation": collation,
                                            "options": options,
                                            "srid": srid},
                                            # {... more members,
                                            #      repeated as many times
                                            #      as needed}
                                            ],
                                "type": type}

    The description is copied on construction and validated on every call
    to :meth:`execute`. The caller's dictionary is never modified, and the
    statement can be executed more than once.
    """

    def __init__(
        self,
        collection: DatabaseTargetType,
        index_name: str,
        index_desc: Dict[str, Any],
    ) -> None:
        super().__init__(target=collection)
        # Deep copy so nothing done here touches the caller's dictionary.
        self._index_desc: Dict[str, Any] = copy.deepcopy(index_desc)
        self._index_name: str = index_name
        # Split off the "fields" list; whatever remains in _index_desc are the
        # index-level members ("type", "unique" or unrecognized ones).
        self._fields_desc: List[Dict[str, Any]] = self._index_desc.pop("fields", [])

    def _validate_index_name(self) -> None:
        """Raise ProgrammingError unless the index name parses as an identifier."""
        if self._index_name is None:
            raise ProgrammingError(ERR_INVALID_INDEX_NAME.format(self._index_name))
        try:
            parsed_ident = ExprParser(self._index_name).expr().get_message()

            # The message is type dict when the Protobuf cext is used
            if isinstance(parsed_ident, dict):
                ident_type = parsed_ident["type"]
            else:
                ident_type = parsed_ident.type
            if ident_type != mysqlxpb_enum("Mysqlx.Expr.Expr.Type.IDENT"):
                raise ProgrammingError(
                    ERR_INVALID_INDEX_NAME.format(self._index_name)
                )
        except (ValueError, AttributeError) as err:
            raise ProgrammingError(
                ERR_INVALID_INDEX_NAME.format(self._index_name)
            ) from err

    @staticmethod
    def _build_constraint(field_desc: Any, index_type: str) -> Dict[str, Any]:
        """Consume the known members of one field description into a constraint.

        Recognized members are popped from ``field_desc`` so the caller can
        report any leftovers as unidentified. A missing "field" or "type"
        member raises KeyError, which the caller converts to a
        ProgrammingError.
        """
        if not isinstance(field_desc, dict):
            raise ProgrammingError(
                f"Each member of 'fields' must be a dict, got: {field_desc!r}"
            )
        constraint: Dict[str, Any] = {
            "member": field_desc.pop("field"),
            "type": field_desc.pop("type"),
            "required": field_desc.pop("required", False),
            "array": field_desc.pop("array", False),
        }
        if not isinstance(constraint["required"], bool):
            raise TypeError("Field member 'required' must be Boolean")
        if not isinstance(constraint["array"], bool):
            raise TypeError("Field member 'array' must be Boolean")
        if index_type.upper() == "SPATIAL" and not constraint["required"]:
            raise ProgrammingError(
                "Field member 'required' must be set to 'True' when "
                "index type is set to 'SPATIAL'"
            )
        # Compare case-insensitively, like the 'options'/'srid' checks below;
        # str() keeps non-string types passing through as before.
        if (
            index_type.upper() == "INDEX"
            and str(constraint["type"]).upper() == "GEOJSON"
        ):
            raise ProgrammingError(
                "Index 'type' must be set to 'SPATIAL' when field "
                "type is set to 'GEOJSON'"
            )
        if "collation" in field_desc:
            if not constraint["type"].upper().startswith("TEXT"):
                raise ProgrammingError(
                    "The 'collation' member can only be used when "
                    "field type is set to 'TEXT', not "
                    f"'{constraint['type'].upper()}'"
                )
            constraint["collation"] = field_desc.pop("collation")
        # "options" and "srid" fields in IndexField can be
        # present only if "type" is set to "GEOJSON"
        if "options" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'options' member can only be used when "
                    "index type is set to 'GEOJSON'"
                )
            constraint["options"] = field_desc.pop("options")
        if "srid" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'srid' member can only be used when index "
                    "type is set to 'GEOJSON'"
                )
            constraint["srid"] = field_desc.pop("srid")
        return constraint

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If the index name or description is invalid.
            NotSupportedError: If a unique index is requested.
            TypeError: If a field's 'required' or 'array' member is not a bool.
        """
        # Validate index name is a valid identifier
        self._validate_index_name()

        # Validation pops members as it consumes them. Work on copies so the
        # instance state survives and a repeated execute() sees the same
        # description (previously the second call failed or lost "type").
        index_desc = copy.deepcopy(self._index_desc)
        fields_desc = copy.deepcopy(self._fields_desc)

        # Validate members that constrain the index
        if not fields_desc:
            raise ProgrammingError(
                "Required member 'fields' not found in the given index "
                f"description: {index_desc}"
            )

        if not isinstance(fields_desc, list):
            raise ProgrammingError("Required member 'fields' must contain a list")

        args: Dict[str, Any] = {
            "name": self._index_name,
            "collection": self._target.name,
            "schema": self._target.schema.name,
            "type": index_desc.pop("type", "INDEX"),
            "unique": index_desc.pop("unique", False),
        }
        # Currently unique indexes are not supported:
        if args["unique"]:
            raise NotSupportedError("Unique indexes are not supported.")
        args["constraint"] = []

        if index_desc:
            raise ProgrammingError(f"Unidentified fields: {index_desc}")

        try:
            for field_desc in fields_desc:
                args["constraint"].append(
                    self._build_constraint(field_desc, args["type"])
                )
        except KeyError as err:
            raise ProgrammingError(
                f"Required inner member {err} not found in constraint: {field_desc}"
            ) from err

        # Anything left after consuming the known members is unrecognized.
        for field_desc in fields_desc:
            if field_desc:
                raise ProgrammingError(f"Unidentified inner fields: {field_desc}")

        return self._connection.execute_nonquery(
            "mysqlx", "create_collection_index", True, args
        )
|