Fixed an issue that prevented the password reset tokens from working. Added email templates for password reset success and new account creation. Added more dynamic email template support.
1214 lines
42 KiB
Python
1214 lines
42 KiB
Python
# Copyright (c) 2016, 2023, Oracle and/or its affiliates.
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License, version 2.0, as
|
|
# published by the Free Software Foundation.
|
|
#
|
|
# This program is also distributed with certain software (including
|
|
# but not limited to OpenSSL) that is licensed under separate terms,
|
|
# as designated in a particular file or component or in included license
|
|
# documentation. The authors of MySQL hereby grant you an
|
|
# additional permission to link the program and your derivative works
|
|
# with the separately licensed software that they have included with
|
|
# MySQL.
|
|
#
|
|
# Without limiting anything contained in the foregoing, this file,
|
|
# which is part of MySQL Connector/Python, is also subject to the
|
|
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
|
# http://oss.oracle.com/licenses/universal-foss-exception.
|
|
#
|
|
# This program is distributed in the hope that it will be useful, but
|
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
# See the GNU General Public License, version 2.0, for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
|
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
|
|
|
"""Implementation of the X protocol for MySQL servers."""
|
|
|
|
import struct
|
|
import zlib
|
|
|
|
from io import BytesIO
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
try:
|
|
import lz4.frame
|
|
|
|
HAVE_LZ4 = True
|
|
except ImportError:
|
|
HAVE_LZ4 = False
|
|
|
|
try:
|
|
import zstandard as zstd
|
|
|
|
HAVE_ZSTD = True
|
|
except ImportError:
|
|
HAVE_ZSTD = False
|
|
|
|
from .errors import (
|
|
InterfaceError,
|
|
NotSupportedError,
|
|
OperationalError,
|
|
ProgrammingError,
|
|
)
|
|
from .expr import (
|
|
ExprParser,
|
|
build_bool_scalar,
|
|
build_expr,
|
|
build_int_scalar,
|
|
build_scalar,
|
|
build_unsigned_int_scalar,
|
|
)
|
|
from .helpers import encode_to_bytes, get_item_or_attr
|
|
from .logger import logger
|
|
from .protobuf import CRUD_PREPARE_MAPPING, SERVER_MESSAGES, Message, mysqlxpb_enum
|
|
from .result import Column
|
|
from .statement import (
|
|
AddStatement,
|
|
DeleteStatement,
|
|
FilterableStatement,
|
|
FindStatement,
|
|
InsertStatement,
|
|
ModifyStatement,
|
|
ReadStatement,
|
|
RemoveStatement,
|
|
SqlStatement,
|
|
UpdateStatement,
|
|
)
|
|
from .types import (
|
|
ColumnType,
|
|
MessageType,
|
|
ProtobufMessageCextType,
|
|
ProtobufMessageType,
|
|
ResultBaseType,
|
|
SocketType,
|
|
StatementType,
|
|
StrOrBytes,
|
|
)
|
|
|
|
# Outgoing messages whose size (as reported by Message.byte_size) is strictly
# greater than this many bytes are wrapped in a compression frame by
# MessageWriter.write_message, but only once compression has been enabled.
_COMPRESSION_THRESHOLD = 1000
|
|
|
|
|
|
class Compressor:
    """Implements compression/decompression using `zstd_stream`, `lz4_message`
    and `deflate_stream` algorithms.

    Args:
        algorithm (str): Compression algorithm.

    Raises:
        :class:`mysqlx.NotSupportedError`: If `algorithm` is `zstd_stream` or
                                           `lz4_message` and the package it
                                           needs could not be imported.

    .. versionadded:: 8.0.21
    """

    def __init__(self, algorithm: str) -> None:
        self._algorithm: str = algorithm
        # Stateful (de)compressors, only used by the stream algorithms;
        # `lz4_message` creates a fresh frame (de)compressor on every call.
        self._compressobj: Any = None
        self._decompressobj: Any = None

        # Fail here with a clear message instead of a NameError on the first
        # (de)compression when the optional package was not importable.
        if algorithm == "zstd_stream" and not HAVE_ZSTD:
            raise NotSupportedError(
                "The 'zstd_stream' compression algorithm requires the "
                "'zstandard' package"
            )
        if algorithm == "lz4_message" and not HAVE_LZ4:
            raise NotSupportedError(
                "The 'lz4_message' compression algorithm requires the "
                "'lz4' package"
            )

        if algorithm == "zstd_stream":
            self._compressobj = zstd.ZstdCompressor()
            self._decompressobj = zstd.ZstdDecompressor()
        elif algorithm == "deflate_stream":
            self._compressobj = zlib.compressobj()
            self._decompressobj = zlib.decompressobj()

    def compress(self, data: StrOrBytes) -> bytes:
        """Compresses data and returns it.

        Args:
            data (str, bytes or buffer object): Data to be compressed.

        Returns:
            bytes: Compressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._compressobj.compress(data)

        if self._algorithm == "lz4_message":
            with lz4.frame.LZ4FrameCompressor() as compressor:
                compressed = compressor.begin()
                compressed += compressor.compress(data)
                compressed += compressor.flush()
            return compressed

        # Using 'deflate_stream' algorithm. Z_SYNC_FLUSH emits all pending
        # output without ending the stream, so the compression state carries
        # over to the next message.
        compressed = self._compressobj.compress(data)
        compressed += self._compressobj.flush(zlib.Z_SYNC_FLUSH)
        return compressed

    def decompress(self, data: StrOrBytes) -> bytes:
        """Decompresses a frame of data and returns it as a string of bytes.

        Args:
            data (str, bytes or buffer object): Data to be decompressed.

        Returns:
            bytes: Decompressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._decompressobj.decompress(data)

        if self._algorithm == "lz4_message":
            with lz4.frame.LZ4FrameDecompressor() as decompressor:
                decompressed = decompressor.decompress(data)
            return decompressed

        # Using 'deflate' algorithm.
        decompressed = self._decompressobj.decompress(data)
        # Decompress.flush() takes an optional initial output-buffer length,
        # not a flush mode, so no argument is passed (zlib.Z_SYNC_FLUSH == 2
        # would only be read as a 2-byte buffer size hint).
        decompressed += self._decompressobj.flush()
        return decompressed
|
|
|
|
|
|
class MessageReader:
    """Implements a Message Reader.

    Reads length-prefixed X Protocol frames from a socket stream and turns
    them into :class:`mysqlx.protobuf.Message` objects, unpacking compressed
    frames once compression has been enabled.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None
        # One-message slot filled by push_message() and drained first by
        # read_message().
        self._msg: MessageType = None
        # Messages unpacked from a compressed frame, not yet returned.
        self._msg_queue: List[Message] = []

    @staticmethod
    def _is_empty_notice(msg_type: int, payload: bytes) -> bool:
        """Return ``True`` for a notice (type 11) with an empty payload.

        ``Message`` requires a type in the payload, so such notices cannot be
        parsed and are skipped.
        """
        return msg_type == 11 and payload == b""

    def _unpack_compressed_frame(self, frame_msg: MessageType) -> None:
        """Decompress a compression frame and queue the messages it carries.

        Args:
            frame_msg (mysqlx.protobuf.Message): The compression message.

        Raises:
            ValueError: If an embedded message has an unknown type.
        """
        uncompressed_size = frame_msg["uncompressed_size"]
        stream = BytesIO(self._compressor.decompress(frame_msg["payload"]))
        bytes_processed = 0
        while bytes_processed < uncompressed_size:
            payload_size, msg_type = struct.unpack("<LB", stream.read(5))
            payload = stream.read(payload_size - 1)
            # payload_size covers the type byte and payload; add the 4-byte
            # length prefix.
            bytes_processed += payload_size + 4
            if msg_type not in SERVER_MESSAGES:
                raise ValueError(f"Unknown message type: {msg_type}")
            # Same rule as for uncompressed frames.
            if self._is_empty_notice(msg_type, payload):
                continue
            self._msg_queue.append(Message.from_server_message(msg_type, payload))

    def _read_message(self) -> MessageType:
        """Reads X Protocol messages from the stream and returns a
        :class:`mysqlx.protobuf.Message` object.

        Empty notices are skipped. Compressed frames are unpacked and the
        messages they contain are returned one per call. If a compressed frame
        carries no parseable message, the next frame is read.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the connected server does not
                                              have the MySQL X protocol plugin
                                              enabled.
            ValueError: If a frame has an unknown message type.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        # Iterate instead of recursing so a long run of skipped frames cannot
        # exhaust the call stack.
        while True:
            if self._msg_queue:
                return self._msg_queue.pop(0)

            frame_size, frame_type = struct.unpack("<LB", self._stream.read(5))

            if frame_type == 10:
                raise ProgrammingError(
                    "The connected server does not have the "
                    "MySQL X protocol plugin enabled or "
                    "protocol mismatch"
                )

            frame_payload = self._stream.read(frame_size - 1)
            if frame_type not in SERVER_MESSAGES:
                raise ValueError(f"Unknown message type: {frame_type}")

            # Do not parse empty notices, Message requires a type in payload
            if self._is_empty_notice(frame_type, frame_payload):
                continue

            frame_msg = Message.from_server_message(frame_type, frame_payload)

            if frame_type != 19:  # Mysqlx.ServerMessages.Type.COMPRESSION
                return frame_msg

            # Queue the embedded messages; the next iteration returns the
            # first one, or reads another frame if nothing was queued.
            self._unpack_compressed_frame(frame_msg)

    def read_message(self) -> MessageType:
        """Read message.

        A message stored with :meth:`push_message` is returned first.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        if self._msg is not None:
            msg = self._msg
            self._msg = None
            return msg
        return self._read_message()

    def push_message(self, msg: MessageType) -> None:
        """Push message.

        Args:
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        Raises:
            :class:`mysqlx.OperationalError`: If message push slot is full.
        """
        if self._msg is not None:
            raise OperationalError("Message push slot is full")
        self._msg = msg

    def set_compression(self, algorithm: str) -> None:
        """Creates a :class:`mysqlx.protocol.Compressor` object based on the
        compression algorithm.

        A falsy `algorithm` disables decompression.

        Args:
            algorithm (str): Compression algorithm.

        .. versionadded:: 8.0.21
        """
        self._compressor = Compressor(algorithm) if algorithm else None
|
|
|
|
|
|
class MessageWriter:
    """Implements a Message Writer.

    Frames X Protocol messages (little-endian 4-byte length followed by a
    1-byte type and the payload) and sends them over a socket stream. Large
    messages are wrapped in a compression frame once compression is enabled.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._compressor: Optional[Compressor] = None
        self._stream: SocketType = socket_stream

    def write_message(self, msg_type: int, msg: MessageType) -> None:
        """Write message.

        When a compressor is set and the message is larger than
        ``_COMPRESSION_THRESHOLD`` bytes, the framed message is compressed and
        sent inside a ``Mysqlx.ClientMessages.Type.COMPRESSION`` frame.
        Otherwise it is sent as a plain frame.

        Args:
            msg_type (int): The message type.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        """
        size = msg.byte_size(msg)
        use_compression = self._compressor and size > _COMPRESSION_THRESHOLD

        body = encode_to_bytes(msg.serialize_to_string())
        # The length field counts the type byte plus the payload.
        frame = b"".join([struct.pack("<LB", size + 1, msg_type), body])

        if not use_compression:
            self._stream.sendall(frame)
            return

        packed = self._compressor.compress(frame)

        fields_part = Message("Mysqlx.Connection.Compression")
        fields_part["client_messages"] = msg_type
        fields_part["uncompressed_size"] = size + 5

        payload_part = Message("Mysqlx.Connection.Compression")
        payload_part["payload"] = packed

        # The last two bytes of the partial serialization of the first part
        # are dropped before appending the payload part (presumably an empty
        # serialized `payload` field - TODO confirm).
        compression_body = b"".join(
            [
                encode_to_bytes(fields_part.serialize_partial_to_string())[:-2],
                encode_to_bytes(payload_part.serialize_partial_to_string()),
            ]
        )

        compression_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.COMPRESSION")
        compression_header = struct.pack(
            "<LB", len(compression_body) + 1, compression_id
        )
        self._stream.sendall(b"".join([compression_header, compression_body]))

    def set_compression(self, algorithm: str) -> None:
        """Creates a :class:`mysqlx.protocol.Compressor` object based on the
        compression algorithm.

        A falsy `algorithm` disables compression.

        Args:
            algorithm (str): Compression algorithm.
        """
        if algorithm:
            self._compressor = Compressor(algorithm)
        else:
            self._compressor = None
|
|
|
|
|
|
class Protocol:
|
|
"""Implements the MySQL X Protocol.
|
|
|
|
Args:
|
|
read (mysqlx.protocol.MessageReader): A Message Reader object.
|
|
writer (mysqlx.protocol.MessageWriter): A Message Writer object.
|
|
|
|
.. versionchanged:: 8.0.21
|
|
"""
|
|
|
|
def __init__(self, reader: MessageReader, writer: MessageWriter) -> None:
    # Endpoints used to receive and send X Protocol messages.
    self._reader: MessageReader = reader
    self._writer: MessageWriter = writer
    # Algorithm passed to set_compression(); None until then.
    self._compression_algorithm: Optional[str] = None
    # Text of every warning notice seen by _process_frame().
    self._warnings: List[str] = []
|
|
|
|
@property
def compression_algorithm(self) -> Optional[str]:
    """str: The compression algorithm set via set_compression(), or ``None``
    if it was never set."""
    algorithm = self._compression_algorithm
    return algorithm
|
|
|
|
@staticmethod
def _apply_filter(msg: MessageType, stmt: FilterableStatement) -> None:
    """Copy the filtering clauses defined on ``stmt`` into ``msg``.

    Only clauses the statement actually has are written: ``criteria``
    (where), ``order`` (sort), ``grouping`` (group by) and
    ``grouping_criteria`` (having).

    Args:
        msg (mysqlx.protobuf.Message): The MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.
    """
    # where -> criteria
    if stmt.has_where:
        msg["criteria"] = stmt.get_where_expr()
    # sort -> order (repeated field)
    if stmt.has_sort:
        msg["order"].extend(stmt.get_sort_expr())
    # group by -> grouping (repeated field)
    if stmt.has_group_by:
        msg["grouping"].extend(stmt.get_grouping())
    # having -> grouping_criteria
    if stmt.has_having:
        msg["grouping_criteria"] = stmt.get_having()
|
|
|
|
def _create_any(self, arg: Any) -> Optional[MessageType]:
    """Wrap a Python value in a ``Mysqlx.Datatypes.Any`` message.

    Supported conversions:

    * ``str``, ``bool`` and ``int`` become scalar values (negative ints are
      encoded as signed, the others as unsigned).
    * A 2-item ``tuple`` ``(key, value)`` becomes an object with one field.
    * A ``dict`` becomes an object with one field per key.
    * A non-empty ``list``/``tuple`` whose first item is a ``dict`` becomes an
      array of objects.
    * Any other ``list`` is read as ``(key, value)`` pairs and becomes an
      object (an empty list gives an empty object).

    Field values are converted recursively.

    Args:
        arg (object): Arbitrary object.

    Returns:
        mysqlx.protobuf.Message: MySQL X Protobuf Message, or ``None`` if the
        type of ``arg`` is not supported.
    """

    def build_object(pairs: Any) -> MessageType:
        # Build a Mysqlx.Datatypes.Object from (key, value) pairs.
        fields = [
            Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=key,
                value=self._create_any(value),
            ).get_message()
            for key, value in pairs
        ]
        return Message("Mysqlx.Datatypes.Object", fld=fields)

    if isinstance(arg, str):
        value = Message("Mysqlx.Datatypes.Scalar.String", value=arg)
        scalar = Message("Mysqlx.Datatypes.Scalar", type=8, v_string=value)
        return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)

    # bool is checked before int because bool is a subclass of int.
    if isinstance(arg, bool):
        return Message(
            "Mysqlx.Datatypes.Any", type=1, scalar=build_bool_scalar(arg)
        )

    if isinstance(arg, int):
        scalar = (
            build_int_scalar(arg) if arg < 0 else build_unsigned_int_scalar(arg)
        )
        return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)

    if isinstance(arg, tuple) and len(arg) == 2:
        return Message("Mysqlx.Datatypes.Any", type=2, obj=build_object([arg]))

    # A dict maps directly to an object. Iterating it like a list of dicts
    # would walk its keys and fail on ``str.items()``.
    if isinstance(arg, dict):
        return Message(
            "Mysqlx.Datatypes.Any", type=2, obj=build_object(arg.items())
        )

    # The emptiness check avoids an IndexError on ``arg[0]``.
    if isinstance(arg, (list, tuple)) and arg and isinstance(arg[0], dict):
        # Array can only handle Any types, Mysqlx.Datatypes.Any.obj
        array_values = [
            Message(
                "Mysqlx.Datatypes.Any", type=2, obj=build_object(items.items())
            ).get_message()
            for items in arg
        ]
        msg = Message("Mysqlx.Datatypes.Array")
        msg["value"] = array_values
        return Message("Mysqlx.Datatypes.Any", type=3, array=msg)

    if isinstance(arg, list):
        return Message("Mysqlx.Datatypes.Any", type=2, obj=build_object(arg))

    return None
|
|
|
|
def _get_binding_args(
    self, stmt: Union[FilterableStatement, SqlStatement], is_scalar: bool = True
) -> Union[List[None], List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
    """Encode the values bound to ``stmt`` as protobuf argument messages.

    Args:
        stmt (Statement): A `Statement` based type object.
        is_scalar (bool): `True` to encode values as ``Scalar`` messages,
                          `False` to encode them as ``Any`` messages.

    Raises:
        :class:`mysqlx.ProgrammingError`: If the number of bindings and
                                          placeholders differ, or a binding
                                          name has no placeholder.

    Returns:
        list: A list of ``Any`` or ``Scalar`` objects, ordered by placeholder
        position for named bindings.
    """

    def encode(
        value: Any,
    ) -> Union[ProtobufMessageType, ProtobufMessageCextType]:
        builder = build_scalar if is_scalar else self._create_any
        return builder(value).get_message()

    bindings = stmt.get_bindings()
    binding_map = stmt.get_binding_map()

    # A missing placeholder map means a SqlStatement: bind positionally.
    if binding_map is None:
        return [encode(value) for value in bindings]

    if len(binding_map) != len(bindings):
        raise ProgrammingError(
            "The number of bind parameters and placeholders do not match"
        )

    encoded: List[Any] = [None] * len(binding_map)
    for name, value in bindings.items():  # type: ignore[union-attr]
        if name not in binding_map:
            raise ProgrammingError(
                f"Unable to find placeholder for parameter: {name}"
            )
        encoded[binding_map[name]] = encode(value)
    return encoded
|
|
|
|
def _process_frame(self, msg: MessageType, result: ResultBaseType) -> None:
    """Process frame.

    Applies a ``Mysqlx.Notice.Frame`` to ``result`` according to its
    ``type``:

    * ``1`` (warning): stored in ``self._warnings``, logged and appended to
      ``result``.
    * ``2`` (session variable changed): parsed, otherwise ignored.
    * ``3`` (session state changed): generated document IDs, rows affected
      and the generated insert ID are copied to ``result``; other parameters
      are ignored.

    Args:
        msg (mysqlx.protobuf.Message): A MySQL X Protobuf Message.
        result (Result): A `Result` based type object.
    """
    if msg["type"] == 1:
        warn_msg = Message.from_message("Mysqlx.Notice.Warning", msg["payload"])
        self._warnings.append(warn_msg.msg)
        logger.warning(
            "Protocol.process_frame Received Warning Notice code %s: %s",
            warn_msg.code,
            warn_msg.msg,
        )
        result.append_warning(warn_msg.level, warn_msg.code, warn_msg.msg)
    elif msg["type"] == 2:
        Message.from_message("Mysqlx.Notice.SessionVariableChanged", msg["payload"])
    elif msg["type"] == 3:
        sess_state_msg = Message.from_message(
            "Mysqlx.Notice.SessionStateChanged", msg["payload"]
        )
        param = sess_state_msg["param"]
        if param == mysqlxpb_enum(
            "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_DOCUMENT_IDS"
        ):
            result.set_generated_ids(
                [
                    get_item_or_attr(
                        get_item_or_attr(value, "v_octets"), "value"
                    ).decode()
                    for value in sess_state_msg["value"]
                ]
            )
            return

        # Following results are unitary and not a list. An empty value list
        # is skipped: pop() would raise IndexError, which _read_message does
        # not catch (it only catches AttributeError/KeyError). Presumably
        # some parameters arrive without a value - TODO confirm against the
        # X Protocol SessionStateChanged specification.
        values = sess_state_msg["value"]
        if not values:
            return
        sess_state_value = values.pop()
        if param == mysqlxpb_enum(
            "Mysqlx.Notice.SessionStateChanged.Parameter.ROWS_AFFECTED"
        ):
            result.set_rows_affected(
                get_item_or_attr(sess_state_value, "v_unsigned_int")
            )
        elif param == mysqlxpb_enum(
            "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_INSERT_ID"
        ):
            result.set_generated_insert_id(
                get_item_or_attr(sess_state_value, "v_unsigned_int")
            )
|
|
|
|
def _read_message(self, result: ResultBaseType) -> Optional[MessageType]:
    """Read messages until one the caller has to handle arrives.

    Notice frames are applied to ``result`` and skipped. ``FetchDone`` and
    ``FetchDoneMoreResultsets`` update ``result`` and reading continues.

    Args:
        result (Result): A `Result` based type object.

    Raises:
        :class:`mysqlx.OperationalError`: If the server sends ``Mysqlx.Error``.
        RuntimeError: If reading fails. When ``result`` already holds
                      warnings, they are appended to the error message.

    Returns:
        mysqlx.protobuf.Message: A ``Mysqlx.Resultset.Row`` (after flagging
        ``result`` as having data) or any other unhandled message; ``None``
        when ``Mysqlx.Sql.StmtExecuteOk`` is received.
    """
    while True:
        try:
            msg = self._reader.read_message()
        except RuntimeError as err:
            # Test the warnings themselves: repr() of even an empty list is
            # a non-empty string. Without warnings, re-raise the original
            # error rather than continuing with `msg` unbound.
            warnings = result.get_warnings()
            if warnings:
                raise RuntimeError(f"{err} reason: {warnings!r}") from err
            raise
        if msg.type == "Mysqlx.Error":
            raise OperationalError(msg["msg"], msg["code"])
        if msg.type == "Mysqlx.Notice.Frame":
            try:
                self._process_frame(msg, result)
            except (AttributeError, KeyError):
                continue
        elif msg.type == "Mysqlx.Sql.StmtExecuteOk":
            return None
        elif msg.type == "Mysqlx.Resultset.FetchDone":
            result.set_closed(True)
        elif msg.type == "Mysqlx.Resultset.FetchDoneMoreResultsets":
            result.set_has_more_results(True)
        elif msg.type == "Mysqlx.Resultset.Row":
            result.set_has_data(True)
            break
        else:
            break
    return msg
|
|
|
|
def set_compression(self, algorithm: str) -> None:
    """Record ``algorithm`` and enable it for both reading and writing.

    Args:
        algorithm (str): Algorithm to be used in compression/decompression.

    .. versionadded:: 8.0.21
    """
    self._compression_algorithm = algorithm
    # Reader (downlink) first, then writer (uplink).
    for endpoint in (self._reader, self._writer):
        endpoint.set_compression(algorithm)
|
|
|
|
def get_capabilites(self) -> MessageType:
    """Ask the server for its capabilities.

    Notice frames received before the reply are discarded.

    Raises:
        :class:`mysqlx.OperationalError`: If the server replies with
                                          ``Mysqlx.Error``.

    Returns:
        mysqlx.protobuf.Message: The server reply.
    """
    request = Message("Mysqlx.Connection.CapabilitiesGet")
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_GET"),
        request,
    )

    while True:
        reply = self._reader.read_message()
        if reply.type != "Mysqlx.Notice.Frame":
            break

    if reply.type == "Mysqlx.Error":
        raise OperationalError(reply["msg"], reply["code"])
    return reply
|
|
|
|
def set_capabilities(self, **kwargs: Any) -> None:
    """Send one ``Mysqlx.Connection.CapabilitiesSet`` with a capability per
    keyword argument, then read the server's reply.

    ``dict`` values are sent as objects with one field per key. Other values
    are converted with ``_create_any``. Nothing is sent when no keyword
    arguments are given.

    Args:
        **kwargs: Arbitrary keyword arguments.

    Returns:
        mysqlx.protobuf.Message: MySQL X Protobuf Message.
    """
    if not kwargs:
        return None

    capabilities = Message("Mysqlx.Connection.Capabilities")
    for cap_name, setting in kwargs.items():
        capability = Message("Mysqlx.Connection.Capability")
        capability["name"] = cap_name
        if isinstance(setting, dict):
            fields = [
                Message(
                    "Mysqlx.Datatypes.Object.ObjectField",
                    key=field_name,
                    value=self._create_any(setting[field_name]),
                ).get_message()
                for field_name in setting
            ]
            obj = Message("Mysqlx.Datatypes.Object", fld=fields)
            wrapped = Message("Mysqlx.Datatypes.Any", type=2, obj=obj)
            capability["value"] = wrapped.get_message()
        else:
            capability["value"] = self._create_any(setting)
        capabilities["capabilities"].extend([capability.get_message()])

    request = Message("Mysqlx.Connection.CapabilitiesSet")
    request["capabilities"] = capabilities
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_SET"),
        request,
    )

    try:
        return self.read_ok()
    except InterfaceError as err:
        # Error 5002 is tolerated: capability "session_connect_attrs" is
        # only available on version >= 8.0.16.
        if err.errno == 5002:
            return None
        raise
|
|
|
|
def send_auth_start(
    self,
    method: str,
    auth_data: Optional[str] = None,
    initial_response: Optional[str] = None,
) -> None:
    """Send ``Mysqlx.Session.AuthenticateStart``.

    Optional fields are included only when they are not ``None``.

    Args:
        method (str): Authentication mechanism name.
        auth_data (Optional[str]): Authentication data.
        initial_response (Optional[str]): Initial response.
    """
    start = Message("Mysqlx.Session.AuthenticateStart")
    start["mech_name"] = method
    for field, field_value in (
        ("auth_data", auth_data),
        ("initial_response", initial_response),
    ):
        if field_value is not None:
            start[field] = field_value
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_START"),
        start,
    )
|
|
|
|
def read_auth_continue(self) -> bytes:
    """Read the server's ``Mysqlx.Session.AuthenticateContinue`` reply.

    Notice frames received before it are discarded.

    Raises:
        :class:`InterfaceError`: If the first non-notice message is not
                                 `Mysqlx.Session.AuthenticateContinue`.

    Returns:
        bytes: The ``auth_data`` field of the reply.
    """
    while True:
        reply = self._reader.read_message()
        if reply.type != "Mysqlx.Notice.Frame":
            break
    if reply.type == "Mysqlx.Session.AuthenticateContinue":
        return reply["auth_data"]
    raise InterfaceError(
        "Unexpected message encountered during authentication handshake"
    )
|
|
|
|
def send_auth_continue(self, auth_data: str) -> None:
    """Send ``Mysqlx.Session.AuthenticateContinue`` carrying ``auth_data``.

    Args:
        auth_data (str): Authentication data.
    """
    reply = Message("Mysqlx.Session.AuthenticateContinue", auth_data=auth_data)
    reply_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_CONTINUE")
    self._writer.write_message(reply_id, reply)
|
|
|
|
def read_auth_ok(self) -> None:
    """Consume messages until ``Mysqlx.Session.AuthenticateOk`` arrives.

    Other messages (other than errors) are ignored.

    Raises:
        :class:`mysqlx.InterfaceError`: If message type is `Mysqlx.Error`.
    """
    reply = self._reader.read_message()
    while reply.type != "Mysqlx.Session.AuthenticateOk":
        if reply.type == "Mysqlx.Error":
            raise InterfaceError(reply.msg)
        reply = self._reader.read_message()
|
|
|
|
def send_prepare_prepare(
    self,
    msg_type: str,
    msg: MessageType,
    stmt: Union[
        FindStatement,
        DeleteStatement,
        ModifyStatement,
        ReadStatement,
        RemoveStatement,
        UpdateStatement,
    ],
) -> None:
    """
    Send a ``Mysqlx.Prepare.Prepare`` message for ``stmt`` and read the OK.

    When the statement has a limit and is not an insert, the CRUD message is
    rebuilt from ``stmt`` without its literal limit. The limit is then
    expressed as placeholders positioned after the statement's own bindings:
    a row count for all messages, plus an offset for Find messages.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.

    Raises:
        :class:`mysqlx.NotSupportedError`: If prepared statements are not
                                           supported.
        ValueError: If a limited message is not a Find, Update or Delete.

    .. versionadded:: 8.0.16
    """
    if stmt.has_limit and msg.type != "Mysqlx.Crud.Insert":
        # Remove 'limit' from message by building a new one
        rebuilders = {
            "Mysqlx.Crud.Find": self.build_find,
            "Mysqlx.Crud.Update": self.build_update,
            "Mysqlx.Crud.Delete": self.build_delete,
        }
        rebuild = rebuilders.get(msg.type)
        if rebuild is None:
            raise ValueError(f"Invalid message type: {msg_type}")
        _, msg = rebuild(stmt)  # type: ignore[operator]

        # Build 'limit_expr' message
        first_free = len(stmt.get_bindings())
        placeholder = mysqlxpb_enum("Mysqlx.Expr.Expr.Type.PLACEHOLDER")
        limit_expr = Message("Mysqlx.Crud.LimitExpr")
        limit_expr["row_count"] = Message(
            "Mysqlx.Expr.Expr", type=placeholder, position=first_free
        )
        if msg.type == "Mysqlx.Crud.Find":
            limit_expr["offset"] = Message(
                "Mysqlx.Expr.Expr", type=placeholder, position=first_free + 1
            )
        msg["limit_expr"] = limit_expr

    oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
    wrapped = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
    wrapped["type"] = mysqlxpb_enum(oneof_type)
    wrapped[oneof_op] = msg

    prepare = Message("Mysqlx.Prepare.Prepare")
    prepare["stmt_id"] = stmt.stmt_id
    prepare["stmt"] = wrapped

    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_PREPARE"),
        prepare,
    )

    try:
        self.read_ok()
    except InterfaceError as err:
        raise NotSupportedError from err
|
|
|
|
def send_prepare_execute(
    self, msg_type: str, msg: MessageType, stmt: FilterableStatement
) -> None:
    """
    Send a ``Mysqlx.Prepare.Execute`` message for an already prepared ``stmt``.

    Bound values are sent as ``Any`` arguments. When the statement has a
    limit, they are followed by the limit row count and offset.

    Args:
        msg_type (str): Message ID string. Unused; kept for interface
                        compatibility with the other ``send_*`` methods.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message. Unused; the
                                       server executes the statement
                                       identified by ``stmt.stmt_id``.
        stmt (Statement): A `Statement` based type object.

    .. versionadded:: 8.0.16
    """
    # A Prepare.OneOfMessage wrapping `msg` used to be built here but was never
    # attached to the Execute message, so it has been dropped as dead code.
    execute = Message("Mysqlx.Prepare.Execute")
    execute["stmt_id"] = stmt.stmt_id

    args = self._get_binding_args(stmt, is_scalar=False)
    if args:
        execute["args"].extend(args)

    if stmt.has_limit:
        execute["args"].extend(
            [
                self._create_any(stmt.get_limit_row_count()).get_message(),
                self._create_any(stmt.get_limit_offset()).get_message(),
            ]
        )

    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_EXECUTE"),
        execute,
    )
|
|
|
|
def send_prepare_deallocate(self, stmt_id: int) -> None:
    """
    Ask the server to drop prepared statement ``stmt_id`` and read the OK.

    Args:
        stmt_id (int): Statement ID.

    .. versionadded:: 8.0.16
    """
    deallocate = Message("Mysqlx.Prepare.Deallocate")
    deallocate["stmt_id"] = stmt_id
    deallocate_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_DEALLOCATE")
    self._writer.write_message(deallocate_id, deallocate)
    self.read_ok()
|
|
|
|
def send_msg_without_ps(
    self,
    msg_type: str,
    msg: MessageType,
    stmt: Union[FilterableStatement, SqlStatement],
) -> None:
    """
    Send a message directly, without using a prepared statement.

    A literal limit is attached to the message (with an offset only for
    Find messages) and the bound values are appended as arguments.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.

    .. versionadded:: 8.0.16
    """
    if stmt.has_limit:
        limit = Message("Mysqlx.Crud.Limit")
        limit["row_count"] = stmt.get_limit_row_count()  # type: ignore[union-attr]
        if msg.type == "Mysqlx.Crud.Find":
            limit["offset"] = stmt.get_limit_offset()  # type: ignore[union-attr]
        msg["limit"] = limit

    # SQL statements take Any arguments; CRUD messages take scalars.
    wants_scalars = msg_type != "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE"
    bound = self._get_binding_args(stmt, is_scalar=wants_scalars)
    if bound:
        msg["args"].extend(bound)
    self.send_msg(msg_type, msg)
|
|
|
|
def send_msg(self, msg_type: str, msg: MessageType) -> None:
    """
    Resolve ``msg_type`` to its numeric ID and write ``msg``.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    msg_id = mysqlxpb_enum(msg_type)
    self._writer.write_message(msg_id, msg)
|
|
|
|
def build_find(
    self, stmt: Union[FindStatement, ReadStatement]
) -> Tuple[str, MessageType]:
    """Build a ``Mysqlx.Crud.Find`` message from ``stmt``.

    The message includes the projection (if any), the filter clauses, the
    row lock mode (exclusive or shared, if requested) and the lock
    contention option when its value is positive.

    Args:
        stmt (Statement): A :class:`mysqlx.ReadStatement` or
                          :class:`mysqlx.FindStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")
    target = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    find = Message("Mysqlx.Crud.Find", data_model=data_model, collection=target)
    if stmt.has_projection:
        find["projection"] = stmt.get_projection_expr()
    self._apply_filter(find, stmt)

    lock_mode = None
    if stmt.is_lock_exclusive():
        lock_mode = "Mysqlx.Crud.Find.RowLock.EXCLUSIVE_LOCK"
    elif stmt.is_lock_shared():
        lock_mode = "Mysqlx.Crud.Find.RowLock.SHARED_LOCK"
    if lock_mode is not None:
        find["locking"] = mysqlxpb_enum(lock_mode)

    if stmt.lock_contention.value > 0:
        find["locking_options"] = stmt.lock_contention.value

    return "Mysqlx.ClientMessages.Type.CRUD_FIND", find
|
|
|
|
def build_update(
    self, stmt: Union[ModifyStatement, UpdateStatement]
) -> Tuple[str, MessageType]:
    """Build a ``Mysqlx.Crud.Update`` message from ``stmt``.

    The message includes the filter clauses and one ``UpdateOperation`` per
    entry in ``stmt.get_update_ops()``, in iteration order.

    Args:
        stmt (Statement): A :class:`mysqlx.ModifyStatement` or
                          :class:`mysqlx.UpdateStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")
    target = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    update = Message("Mysqlx.Crud.Update", data_model=data_model, collection=target)
    self._apply_filter(update, stmt)

    for op in stmt.get_update_ops().values():
        operation = Message("Mysqlx.Crud.UpdateOperation")
        operation["operation"] = op.update_type
        operation["source"] = op.source
        # Some operations carry no value; only set one when present.
        if op.value is not None:
            operation["value"] = build_expr(op.value)
        update["operation"].extend([operation.get_message()])

    return "Mysqlx.ClientMessages.Type.CRUD_UPDATE", update
|
|
|
|
def build_delete(
    self, stmt: Union[DeleteStatement, RemoveStatement]
) -> Tuple[str, MessageType]:
    """Build a ``Mysqlx.Crud.Delete`` message (target and filter clauses).

    Args:
        stmt (Statement): A :class:`mysqlx.DeleteStatement` or
                          :class:`mysqlx.RemoveStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")
    target = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    delete = Message("Mysqlx.Crud.Delete", data_model=data_model, collection=target)
    self._apply_filter(delete, stmt)
    return "Mysqlx.ClientMessages.Type.CRUD_DELETE", delete
|
|
|
|
def build_execute_statement(
    self,
    namespace: str,
    stmt: Union[str, StatementType],
    fields: Optional[Dict[str, Any]] = None,
) -> Tuple[str, MessageType]:
    """Build a ``Mysqlx.Sql.StmtExecute`` message.

    When ``fields`` is non-empty, it is sent as a single object argument
    with one field per key.

    Args:
        namespace (str): The namespace.
        stmt (Statement): A `Statement` based type object.
        fields (Optional[dict]): The message fields.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    execute = Message(
        "Mysqlx.Sql.StmtExecute",
        namespace=namespace,
        stmt=stmt,
        compact_metadata=False,
    )

    if fields:
        object_fields = [
            Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=field_name,
                value=self._create_any(field_value),
            ).get_message()
            for field_name, field_value in fields.items()
        ]
        obj = Message("Mysqlx.Datatypes.Object", fld=object_fields)
        argument = Message("Mysqlx.Datatypes.Any", type=2, obj=obj)
        execute["args"] = [argument.get_message()]

    return "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE", execute
|
|
|
|
@staticmethod
def build_insert(
    stmt: Union[AddStatement, InsertStatement]
) -> Tuple[str, MessageType]:
    """Build a ``Mysqlx.Crud.Insert`` message from ``stmt``.

    Column projections are added when the statement has ``_fields``. Each
    value becomes one row: list values contribute one field per item, other
    values a single field. ``upsert`` is set when the statement supports it.

    Args:
        stmt (Statement): A :class:`mysqlx.AddStatement` or
                          :class:`mysqlx.InsertStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    is_document = stmt.is_doc_based()
    data_model = mysqlxpb_enum(
        "Mysqlx.Crud.DataModel.DOCUMENT"
        if is_document
        else "Mysqlx.Crud.DataModel.TABLE"
    )
    target = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    insert = Message("Mysqlx.Crud.Insert", data_model=data_model, collection=target)

    if hasattr(stmt, "_fields"):
        for column in stmt._fields:
            projection = ExprParser(
                column, not is_document
            ).parse_table_insert_field()
            insert["projection"].extend([projection.get_message()])

    for row_value in stmt.get_values():
        row = Message("Mysqlx.Crud.Insert.TypedRow")
        cells = row_value if isinstance(row_value, list) else [row_value]
        for cell in cells:
            row["field"].extend([build_expr(cell).get_message()])
        insert["row"].extend([row.get_message()])

    if hasattr(stmt, "is_upsert"):
        insert["upsert"] = stmt.is_upsert()

    return "Mysqlx.ClientMessages.Type.CRUD_INSERT", insert
|
|
|
|
def close_result(self, result: ResultBaseType) -> None:
    """Close the result, checking that nothing is left to read.

    One more message is read for ``result`` via ``self._read_message``.
    If that returns ``None`` (presumably meaning the result has been fully
    consumed), nothing happens.

    Args:
        result (Result): A `Result` based type object.

    Raises:
        :class:`mysqlx.OperationalError`: If a message is still pending
                                          for the result, i.e. the
                                          message read is NOT None.
    """
    msg = self._read_message(result)
    # Any remaining message means the result cannot be closed cleanly.
    if msg is not None:
        raise OperationalError("Expected to close the result")
|
|
|
|
def read_row(self, result: ResultBaseType) -> Optional[MessageType]:
    """Read the next row of a result.

    Args:
        result (Result): A `Result` based type object.

    Returns:
        The ``Mysqlx.Resultset.Row`` message, or ``None`` when no message
        is available or the next message is not a row. A non-row message
        is pushed back onto the reader so a later read can consume it.
    """
    msg = self._read_message(result)
    if msg is None:
        return None
    if msg.type != "Mysqlx.Resultset.Row":
        self._reader.push_message(msg)
        return None
    return msg
|
|
|
|
def get_column_metadata(self, result: ResultBaseType) -> List[ColumnType]:
    """Collect the column metadata messages of a result.

    Messages are read until none are left or the first row arrives. A row
    message is pushed back onto the reader so it can be read later.

    Args:
        result (Result): A `Result` based type object.

    Returns:
        list: One ``Column`` per ``Mysqlx.Resultset.ColumnMetaData``
              message read.

    Raises:
        :class:`mysqlx.InterfaceError`: If unexpected message.
    """
    metadata = []
    msg = self._read_message(result)
    while msg is not None:
        if msg.type == "Mysqlx.Resultset.Row":
            # The first row ends the metadata; leave it for the row reader.
            self._reader.push_message(msg)
            break
        if msg.type != "Mysqlx.Resultset.ColumnMetaData":
            raise InterfaceError("Unexpected msg type")
        metadata.append(
            Column(
                msg["type"],
                msg["catalog"],
                msg["schema"],
                msg["table"],
                msg["original_table"],
                msg["name"],
                msg["original_name"],
                msg.get("length", 21),
                msg.get("collation", 0),
                msg.get("fractional_digits", 0),
                msg.get("flags", 16),
                msg.get("content_type"),
            )
        )
        msg = self._read_message(result)
    return metadata
|
|
|
|
def read_ok(self) -> None:
    """Read one message and require it to be ``Mysqlx.Ok``.

    Raises:
        :class:`mysqlx.InterfaceError`: If the server replied with
                                        ``Mysqlx.Error`` (its text and code
                                        are propagated) or with any other
                                        unexpected message.
    """
    reply = self._reader.read_message()
    kind = reply.type
    if kind == "Mysqlx.Ok":
        return
    if kind == "Mysqlx.Error":
        raise InterfaceError(f"Mysqlx.Error: {reply['msg']}", errno=reply["code"])
    raise InterfaceError("Unexpected message encountered")
|
|
|
|
def send_connection_close(self) -> None:
    """Write a ``Mysqlx.Connection.Close`` message to the server."""
    close_msg = Message("Mysqlx.Connection.Close")
    msg_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CLOSE")
    self._writer.write_message(msg_id, close_msg)
|
|
|
|
def send_close(self) -> None:
    """Write a ``Mysqlx.Session.Close`` message to the server."""
    close_msg = Message("Mysqlx.Session.Close")
    msg_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_CLOSE")
    self._writer.write_message(msg_id, close_msg)
|
|
|
|
def send_expect_open(self) -> None:
    """Write a ``Mysqlx.Expect.Open`` message carrying one condition.

    The condition is ``EXPECT_FIELD_EXIST`` with value ``"6.1"``.
    NOTE(review): "6.1" presumably names field 1 of client message type 6
    (the session-reset ``keep_open`` field) -- confirm against the X
    Protocol documentation.
    """
    field_exists = mysqlxpb_enum(
        "Mysqlx.Expect.Open.Condition.Key.EXPECT_FIELD_EXIST"
    )
    condition = Message("Mysqlx.Expect.Open.Condition")
    condition["condition_key"] = field_exists
    condition["condition_value"] = "6.1"

    expect_open = Message("Mysqlx.Expect.Open")
    expect_open["cond"] = [condition.get_message()]

    msg_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.EXPECT_OPEN")
    self._writer.write_message(msg_id, expect_open)
|
|
|
|
def send_reset(self, keep_open: Optional[bool] = None) -> bool:
    """Send a ``Mysqlx.Session.Reset`` message and wait for its OK.

    Args:
        keep_open (Optional[bool]): Whether to ask the server to keep the
            session open. When ``None``, an expectation is sent first; an
            OK reply means ``True``, while an ``InterfaceError`` means
            ``False``.

    Returns:
        boolean: ``True`` if the server will keep the session open,
                 otherwise ``False``.

    Raises:
        :class:`mysqlx.InterfaceError`: Propagated from ``read_ok`` if the
            reset itself is not acknowledged with ``Mysqlx.Ok``.
    """
    reset_msg = Message("Mysqlx.Session.Reset")
    if keep_open is None:
        try:
            # Send expectation: keep connection open.
            self.send_expect_open()
            self.read_ok()
        except InterfaceError:
            # The expectation is unknown to this version of the server.
            keep_open = False
        else:
            keep_open = True
    if keep_open:
        reset_msg["keep_open"] = True
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_RESET"), reset_msg
    )
    self.read_ok()
    return bool(keep_open)
|