AzerothCore-website/venv/lib/python3.12/site-packages/mysqlx/protocol.py
Aaron Barbas 9bbeb35c08 Added support link to download game client, link for addons.
Fixed an issue that prevented the password reset tokens from working.
Added email templates for password reset success and new account creation.
Added more dynamic email template support.
2024-10-03 22:00:40 -05:00

1214 lines
42 KiB
Python

# Copyright (c) 2016, 2023, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implementation of the X protocol for MySQL servers."""
import struct
import zlib
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
try:
import lz4.frame
HAVE_LZ4 = True
except ImportError:
HAVE_LZ4 = False
try:
import zstandard as zstd
HAVE_ZSTD = True
except ImportError:
HAVE_ZSTD = False
from .errors import (
InterfaceError,
NotSupportedError,
OperationalError,
ProgrammingError,
)
from .expr import (
ExprParser,
build_bool_scalar,
build_expr,
build_int_scalar,
build_scalar,
build_unsigned_int_scalar,
)
from .helpers import encode_to_bytes, get_item_or_attr
from .logger import logger
from .protobuf import CRUD_PREPARE_MAPPING, SERVER_MESSAGES, Message, mysqlxpb_enum
from .result import Column
from .statement import (
AddStatement,
DeleteStatement,
FilterableStatement,
FindStatement,
InsertStatement,
ModifyStatement,
ReadStatement,
RemoveStatement,
SqlStatement,
UpdateStatement,
)
from .types import (
ColumnType,
MessageType,
ProtobufMessageCextType,
ProtobufMessageType,
ResultBaseType,
SocketType,
StatementType,
StrOrBytes,
)
# Size threshold used by MessageWriter.write_message: only messages whose
# byte size is strictly greater than this value are sent compressed.
_COMPRESSION_THRESHOLD = 1000
class Compressor:
    """Compress and decompress data with one of the supported algorithms.

    The supported algorithms are ``zstd_stream``, ``lz4_message`` and
    ``deflate_stream``. For ``zstd_stream`` and ``deflate_stream`` the
    codec objects are created once and reused for every call; for
    ``lz4_message`` a new frame (de)compressor is created per call.

    Args:
        algorithm (str): Compression algorithm.

    .. versionadded:: 8.0.21
    """

    def __init__(self, algorithm: str) -> None:
        self._algorithm: str = algorithm
        # Codec objects stay ``None`` for ``lz4_message`` (and for any
        # unrecognized algorithm name).
        self._compressobj: Any = None
        self._decompressobj: Any = None

        if algorithm == "zstd_stream":
            self._compressobj = zstd.ZstdCompressor()
            self._decompressobj = zstd.ZstdDecompressor()
        elif algorithm == "deflate_stream":
            self._compressobj = zlib.compressobj()
            self._decompressobj = zlib.decompressobj()

    def compress(self, data: StrOrBytes) -> bytes:
        """Compress data and return it.

        Args:
            data (str, bytes or buffer object): Data to be compressed.

        Returns:
            bytes: Compressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._compressobj.compress(data)

        if self._algorithm == "lz4_message":
            with lz4.frame.LZ4FrameCompressor() as compressor:
                compressed = compressor.begin()
                compressed += compressor.compress(data)
                compressed += compressor.flush()
            return compressed

        # Using 'deflate_stream' algorithm. A sync flush emits everything
        # compressed so far while keeping the stream open for later calls.
        compressed = self._compressobj.compress(data)
        compressed += self._compressobj.flush(zlib.Z_SYNC_FLUSH)
        return compressed

    def decompress(self, data: StrOrBytes) -> bytes:
        """Decompress a frame of data and return it as a string of bytes.

        Args:
            data (str, bytes or buffer object): Data to be decompressed.

        Returns:
            bytes: Decompressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._decompressobj.decompress(data)

        if self._algorithm == "lz4_message":
            with lz4.frame.LZ4FrameDecompressor() as decompressor:
                decompressed = decompressor.decompress(data)
            return decompressed

        # Using 'deflate_stream' algorithm.
        decompressed = self._decompressobj.decompress(data)
        # ``Decompress.flush()`` takes an optional output buffer length, not
        # a flush mode, so it is called without ``zlib.Z_SYNC_FLUSH``.
        decompressed += self._decompressobj.flush()
        return decompressed
class MessageReader:
    """Reads X Protocol messages from a socket stream.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None
        # Single-slot buffer filled by push_message()
        self._msg: MessageType = None
        # Messages unpacked from a compression frame, waiting to be returned
        self._msg_queue: List[Message] = []

    def _read_message(self) -> MessageType:
        """Return the next X Protocol message.

        Messages already unpacked from a compression frame are served first;
        otherwise a new frame is read from the stream. Notice frames with an
        empty payload are skipped.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the connected server does not
                                              have the MySQL X protocol plugin
                                              enabled.
            ValueError: If the frame type is not a known server message.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message, or ``None``
            when a compression frame yielded no messages.
        """
        while True:
            if self._msg_queue:
                return self._msg_queue.pop(0)

            # Frame header: 4-byte little-endian size followed by 1 type byte
            header = self._stream.read(5)
            frame_size, frame_type = struct.unpack("<LB", header)
            if frame_type == 10:
                raise ProgrammingError(
                    "The connected server does not have the "
                    "MySQL X protocol plugin enabled or "
                    "protocol mismatch"
                )

            # The size counts the type byte, which has already been consumed
            frame_payload = self._stream.read(frame_size - 1)
            if frame_type not in SERVER_MESSAGES:
                raise ValueError(f"Unknown message type: {frame_type}")

            # Do not parse empty notices, Message requires a type in payload
            if not (frame_type == 11 and frame_payload == b""):
                break

        frame_msg = Message.from_server_message(frame_type, frame_payload)
        if frame_type != 19:  # Not Mysqlx.ServerMessages.Type.COMPRESSION
            return frame_msg

        # Unpack every message contained in the compressed payload
        total_size = frame_msg["uncompressed_size"]
        inner = BytesIO(self._compressor.decompress(frame_msg["payload"]))
        consumed = 0
        while consumed < total_size:
            inner_size, inner_type = struct.unpack("<LB", inner.read(5))
            inner_payload = inner.read(inner_size - 1)
            self._msg_queue.append(
                Message.from_server_message(inner_type, inner_payload)
            )
            # 4 bytes of size field plus the size it announces
            consumed += inner_size + 4

        if self._msg_queue:
            return self._msg_queue.pop(0)
        return None

    def read_message(self) -> MessageType:
        """Return the pushed-back message if there is one, else read one.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        if self._msg is None:
            return self._read_message()
        pending, self._msg = self._msg, None
        return pending

    def push_message(self, msg: MessageType) -> None:
        """Store a message to be returned by the next ``read_message`` call.

        Args:
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        Raises:
            :class:`mysqlx.OperationalError`: If message push slot is full.
        """
        if self._msg is not None:
            raise OperationalError("Message push slot is full")
        self._msg = msg

    def set_compression(self, algorithm: str) -> None:
        """Create a :class:`mysqlx.protocol.Compressor` for the given
        algorithm, or disable compression when ``algorithm`` is empty.

        Args:
            algorithm (str): Compression algorithm.

        .. versionadded:: 8.0.21
        """
        if algorithm:
            self._compressor = Compressor(algorithm)
        else:
            self._compressor = None
class MessageWriter:
    """Writes X Protocol messages to a socket stream.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None

    def write_message(self, msg_type: int, msg: MessageType) -> None:
        """Serialize a message and send it, compressing it when a compressor
        is set and the message is larger than ``_COMPRESSION_THRESHOLD``.

        Args:
            msg_type (int): The message type.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        """
        msg_size = msg.byte_size(msg)
        use_compression = self._compressor and msg_size > _COMPRESSION_THRESHOLD

        if not use_compression:
            body = encode_to_bytes(msg.serialize_to_string())
            # Frame header: size (payload + type byte) and the message type
            frame_header = struct.pack("<LB", msg_size + 1, msg_type)
            self._stream.sendall(b"".join([frame_header, body]))
            return

        # Compress the regular frame (header + serialized message)
        body = encode_to_bytes(msg.serialize_to_string())
        inner_header = struct.pack("<LB", msg_size + 1, msg_type)
        compressed = self._compressor.compress(b"".join([inner_header, body]))

        # The compression message is serialized in two parts which are then
        # concatenated: the leading fields and the payload field.
        leading = Message("Mysqlx.Connection.Compression")
        leading["client_messages"] = msg_type
        leading["uncompressed_size"] = msg_size + 5
        trailing = Message("Mysqlx.Connection.Compression")
        trailing["payload"] = compressed

        # NOTE(review): the last two bytes of the first part are dropped
        # before the payload part is appended; the reason is not visible here.
        leading_bytes = encode_to_bytes(leading.serialize_partial_to_string())[:-2]
        trailing_bytes = encode_to_bytes(trailing.serialize_partial_to_string())
        output = b"".join([leading_bytes, trailing_bytes])

        compression_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.COMPRESSION")
        outer_header = struct.pack("<LB", len(output) + 1, compression_id)
        self._stream.sendall(b"".join([outer_header, output]))

    def set_compression(self, algorithm: str) -> None:
        """Create a :class:`mysqlx.protocol.Compressor` for the given
        algorithm, or disable compression when ``algorithm`` is empty.

        Args:
            algorithm (str): Compression algorithm.
        """
        if algorithm:
            self._compressor = Compressor(algorithm)
        else:
            self._compressor = None
class Protocol:
"""Implements the MySQL X Protocol.
Args:
reader (mysqlx.protocol.MessageReader): A Message Reader object.
writer (mysqlx.protocol.MessageWriter): A Message Writer object.
.. versionchanged:: 8.0.21
"""
def __init__(self, reader: MessageReader, writer: MessageWriter) -> None:
    """Store the reader/writer pair used to exchange messages.

    Args:
        reader (mysqlx.protocol.MessageReader): A Message Reader object.
        writer (mysqlx.protocol.MessageWriter): A Message Writer object.
    """
    self._reader: MessageReader = reader
    self._writer: MessageWriter = writer
    # Stays None until set_compression() stores an algorithm name
    self._compression_algorithm: Optional[str] = None
    # Texts of the warning notices seen by _process_frame()
    self._warnings: List[str] = []
@property
def compression_algorithm(self) -> Optional[str]:
    """str: The compression algorithm, or ``None`` if none has been set."""
    return self._compression_algorithm
@staticmethod
def _apply_filter(msg: MessageType, stmt: FilterableStatement) -> None:
    """Copy the filtering clauses of a statement into a message.

    Each clause is only transferred when the statement reports having it.

    Args:
        msg (mysqlx.protobuf.Message): The MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.
    """
    # (statement flag, message field, statement getter, field is repeated)
    clauses = (
        ("has_where", "criteria", "get_where_expr", False),
        ("has_sort", "order", "get_sort_expr", True),
        ("has_group_by", "grouping", "get_grouping", True),
        ("has_having", "grouping_criteria", "get_having", False),
    )
    for flag, field, getter, repeated in clauses:
        if not getattr(stmt, flag):
            continue
        value = getattr(stmt, getter)()
        if repeated:
            msg[field].extend(value)
        else:
            msg[field] = value
def _create_any(self, arg: Any) -> Optional[MessageType]:
    """Wrap a Python value in a ``Mysqlx.Datatypes.Any`` message.

    Supported values:

    * ``str``, ``bool`` and ``int``: scalar ``Any`` (negative integers are
      signed, the others unsigned).
    * 2-item ``tuple`` ``(key, value)``: object with a single field.
    * ``dict``: object with one field per item.
    * non-empty ``list``/``tuple`` whose first item is a ``dict``: array
      with one object per item.
    * any other ``list``: treated as a sequence of ``(key, value)`` pairs
      and converted to an object.

    Args:
        arg (object): Arbitrary object.

    Returns:
        mysqlx.protobuf.Message: MySQL X Protobuf Message, or ``None`` if
        the value type is not supported.
    """

    def build_object(pairs: Any) -> MessageType:
        # Build an ``Any`` of type object from an iterable of (key, value)
        fields = [
            Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=key,
                value=self._create_any(value),
            ).get_message()
            for key, value in pairs
        ]
        msg_obj = Message("Mysqlx.Datatypes.Object", fld=fields)
        return Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)

    if isinstance(arg, str):
        value = Message("Mysqlx.Datatypes.Scalar.String", value=arg)
        scalar = Message("Mysqlx.Datatypes.Scalar", type=8, v_string=value)
        return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)

    # bool must be tested before int because bool is a subclass of int
    if isinstance(arg, bool):
        return Message("Mysqlx.Datatypes.Any", type=1, scalar=build_bool_scalar(arg))

    if isinstance(arg, int):
        scalar = build_int_scalar(arg) if arg < 0 else build_unsigned_int_scalar(arg)
        return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)

    if isinstance(arg, tuple) and len(arg) == 2:
        return build_object([arg])

    # A plain dict used to be iterated by key and then ``.items()`` was
    # called on each key, which always failed. It is now encoded as an
    # object, the same way dicts are encoded elsewhere in this class.
    if isinstance(arg, dict):
        return build_object(arg.items())

    # ``arg and`` guards the ``arg[0]`` lookup against empty sequences
    if isinstance(arg, (list, tuple)) and arg and isinstance(arg[0], dict):
        # Array can only handle Any types, Mysqlx.Datatypes.Any.obj
        msg = Message("Mysqlx.Datatypes.Array")
        msg["value"] = [build_object(item.items()).get_message() for item in arg]
        return Message("Mysqlx.Datatypes.Any", type=3, array=msg)

    if isinstance(arg, list):
        return build_object(arg)

    return None
def _get_binding_args(
    self, stmt: Union[FilterableStatement, SqlStatement], is_scalar: bool = True
) -> Union[List[None], List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
    """Convert the bound values of a statement to protobuf messages.

    Args:
        stmt (Statement): A `Statement` based type object.
        is_scalar (bool): `True` to return scalar values, `False` to return
                          ``Any`` values.

    Raises:
        :class:`mysqlx.ProgrammingError`: If the number of bindings differs
                                          from the number of placeholders,
                                          or a parameter has no placeholder.

    Returns:
        list: A list of ``Any`` or ``Scalar`` objects, ordered by
        placeholder position when the statement has a binding map.
    """
    bindings = stmt.get_bindings()
    binding_map = stmt.get_binding_map()

    def to_message(
        value: Any,
    ) -> Union[ProtobufMessageType, ProtobufMessageCextType]:
        wrapper = build_scalar(value) if is_scalar else self._create_any(value)
        return wrapper.get_message()

    # If binding_map is None it's a SqlStatement object
    if binding_map is None:
        return [to_message(bound) for bound in bindings]

    placeholders = len(binding_map)
    args: List[Any] = [None] * placeholders
    if placeholders != len(bindings):
        raise ProgrammingError(
            "The number of bind parameters and placeholders do not match"
        )

    for name, bound in bindings.items():  # type: ignore[union-attr]
        if name not in binding_map:
            raise ProgrammingError(
                f"Unable to find placeholder for parameter: {name}"
            )
        # The map gives the position of the placeholder for this name
        args[binding_map[name]] = to_message(bound)
    return args
def _process_frame(self, msg: MessageType, result: ResultBaseType) -> None:
    """Handle a notice frame and record its content on the result.

    * Type 1 (warning): the warning is stored, logged and appended to
      ``result``.
    * Type 2 (session variable changed): the payload is parsed and
      discarded.
    * Type 3 (session state changed): generated document ids, affected rows
      or the generated insert id are stored on ``result``.

    Any other notice type is ignored.

    Args:
        msg (mysqlx.protobuf.Message): A MySQL X Protobuf Message.
        result (Result): A `Result` based type object.
    """
    notice_type = msg["type"]

    if notice_type == 1:
        warning = Message.from_message("Mysqlx.Notice.Warning", msg["payload"])
        self._warnings.append(warning.msg)
        logger.warning(
            "Protocol.process_frame Received Warning Notice code %s: %s",
            warning.code,
            warning.msg,
        )
        result.append_warning(warning.level, warning.code, warning.msg)
        return

    if notice_type == 2:
        Message.from_message("Mysqlx.Notice.SessionVariableChanged", msg["payload"])
        return

    if notice_type != 3:
        return

    state = Message.from_message(
        "Mysqlx.Notice.SessionStateChanged", msg["payload"]
    )
    param = state["param"]

    if param == mysqlxpb_enum(
        "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_DOCUMENT_IDS"
    ):
        generated_ids = []
        for value in state["value"]:
            octets = get_item_or_attr(value, "v_octets")
            generated_ids.append(get_item_or_attr(octets, "value").decode())
        result.set_generated_ids(generated_ids)
        return

    # Following results are unitary and not a list
    state_value = state["value"].pop()
    if param == mysqlxpb_enum(
        "Mysqlx.Notice.SessionStateChanged.Parameter.ROWS_AFFECTED"
    ):
        result.set_rows_affected(get_item_or_attr(state_value, "v_unsigned_int"))
    elif param == mysqlxpb_enum(
        "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_INSERT_ID"
    ):
        result.set_generated_insert_id(
            get_item_or_attr(state_value, "v_unsigned_int")
        )
def _read_message(self, result: ResultBaseType) -> Optional[MessageType]:
    """Read messages until one has to be handed back to the caller.

    Notice frames are processed and skipped, and the "fetch done" messages
    only update ``result``; reading continues after them. A row marks
    ``result`` as having data and is returned, as is any message type not
    handled here.

    Args:
        result (Result): A `Result` based type object.

    Raises:
        :class:`mysqlx.OperationalError`: If the server sent an error
                                          message.
        RuntimeError: If reading from the reader fails; the warnings stored
                      on ``result`` are appended to the error text when
                      there are any.

    Returns:
        mysqlx.protobuf.Message: The message read, or ``None`` once
        ``Mysqlx.Sql.StmtExecuteOk`` is received.
    """
    while True:
        try:
            msg = self._reader.read_message()
        except RuntimeError as err:
            # Test the warnings themselves: their repr() is never empty, so
            # checking it always appended a reason, even "reason: []".
            warnings = result.get_warnings()
            if warnings:
                raise RuntimeError(f"{err} reason: {warnings!r}") from err
            # Never fall through with ``msg`` unbound
            raise

        msg_type = msg.type
        if msg_type == "Mysqlx.Error":
            raise OperationalError(msg["msg"], msg["code"])

        if msg_type == "Mysqlx.Notice.Frame":
            try:
                self._process_frame(msg, result)
            except (AttributeError, KeyError):
                # Best effort: notices that cannot be processed are skipped
                continue
        elif msg_type == "Mysqlx.Sql.StmtExecuteOk":
            return None
        elif msg_type == "Mysqlx.Resultset.FetchDone":
            result.set_closed(True)
        elif msg_type == "Mysqlx.Resultset.FetchDoneMoreResultsets":
            result.set_has_more_results(True)
        elif msg_type == "Mysqlx.Resultset.Row":
            result.set_has_data(True)
            break
        else:
            break
    return msg
def set_compression(self, algorithm: str) -> None:
    """Sets the compression algorithm to be used by the compression
    object, for uplink and downlink.

    Args:
        algorithm (str): Algorithm to be used in compression/decompression.

    .. versionadded:: 8.0.21
    """
    self._compression_algorithm = algorithm
    # The same algorithm is applied to both directions
    self._reader.set_compression(algorithm)
    self._writer.set_compression(algorithm)
def get_capabilites(self) -> MessageType:
    """Request the server capabilities and return the reply.

    Notice frames received before the reply are skipped.

    Raises:
        :class:`mysqlx.OperationalError`: If the server replies with an
                                          error message.

    Returns:
        mysqlx.protobuf.Message: MySQL X Protobuf Message.
    """
    request = Message("Mysqlx.Connection.CapabilitiesGet")
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_GET"),
        request,
    )

    while True:
        reply = self._reader.read_message()
        if reply.type != "Mysqlx.Notice.Frame":
            break

    if reply.type == "Mysqlx.Error":
        raise OperationalError(reply["msg"], reply["code"])
    return reply
def set_capabilities(self, **kwargs: Any) -> None:
    """Send the given capabilities to the server and wait for the reply.

    Nothing is sent when no capability is given. Each keyword argument
    becomes one capability; ``dict`` values are encoded as an object.

    Args:
        **kwargs: Arbitrary keyword arguments.

    Raises:
        :class:`mysqlx.InterfaceError`: If the server does not reply with OK,
                                        unless the error number is 5002.
    """
    if not kwargs:
        return None

    capabilities = Message("Mysqlx.Connection.Capabilities")
    for name, value in kwargs.items():
        capability = Message("Mysqlx.Connection.Capability")
        capability["name"] = name
        if isinstance(value, dict):
            fields = [
                Message(
                    "Mysqlx.Datatypes.Object.ObjectField",
                    key=field_name,
                    value=self._create_any(field_value),
                ).get_message()
                for field_name, field_value in value.items()
            ]
            msg_obj = Message("Mysqlx.Datatypes.Object", fld=fields)
            msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
            capability["value"] = msg_any.get_message()
        else:
            capability["value"] = self._create_any(value)
        capabilities["capabilities"].extend([capability.get_message()])

    request = Message("Mysqlx.Connection.CapabilitiesSet")
    request["capabilities"] = capabilities
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_SET"),
        request,
    )

    try:
        return self.read_ok()
    except InterfaceError as err:
        # Skip capability "session_connect_attrs" error since
        # is only available on version >= 8.0.16
        if err.errno == 5002:
            return None
        raise
def send_auth_start(
    self,
    method: str,
    auth_data: Optional[str] = None,
    initial_response: Optional[str] = None,
) -> None:
    """Send the message that starts the authentication.

    The optional fields are only set on the message when a value other
    than ``None`` is given.

    Args:
        method (str): Message method.
        auth_data (Optional[str]): Authentication data.
        initial_response (Optional[str]): Initial response.
    """
    msg = Message("Mysqlx.Session.AuthenticateStart")
    msg["mech_name"] = method

    optional_fields = (
        ("auth_data", auth_data),
        ("initial_response", initial_response),
    )
    for field, value in optional_fields:
        if value is not None:
            msg[field] = value

    msg_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_START")
    self._writer.write_message(msg_id, msg)
def read_auth_continue(self) -> bytes:
    """Read the authentication data sent by the server.

    Notice frames received before the reply are skipped.

    Raises:
        :class:`InterfaceError`: If the message type is not
                                 `Mysqlx.Session.AuthenticateContinue`

    Returns:
        bytes: The authentication data.
    """
    while True:
        reply = self._reader.read_message()
        if reply.type != "Mysqlx.Notice.Frame":
            break

    if reply.type == "Mysqlx.Session.AuthenticateContinue":
        return reply["auth_data"]
    raise InterfaceError(
        "Unexpected message encountered during authentication handshake"
    )
def send_auth_continue(self, auth_data: str) -> None:
    """Send the client's authentication data to the server.

    Args:
        auth_data (str): Authentication data.
    """
    msg = Message("Mysqlx.Session.AuthenticateContinue", auth_data=auth_data)
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_CONTINUE"),
        msg,
    )
def read_auth_ok(self) -> None:
    """Read messages until the authentication is acknowledged.

    Every message other than the acknowledgement or an error is ignored.

    Raises:
        :class:`mysqlx.InterfaceError`: If message type is `Mysqlx.Error`.
    """
    reply = self._reader.read_message()
    while reply.type != "Mysqlx.Session.AuthenticateOk":
        if reply.type == "Mysqlx.Error":
            raise InterfaceError(reply.msg)
        reply = self._reader.read_message()
def send_prepare_prepare(
    self,
    msg_type: str,
    msg: MessageType,
    stmt: Union[
        FindStatement,
        DeleteStatement,
        ModifyStatement,
        ReadStatement,
        RemoveStatement,
        UpdateStatement,
    ],
) -> None:
    """
    Send the message that prepares a statement and wait for the reply.

    When the statement has a limit (and is not an insert), the message is
    rebuilt and the limit is expressed with placeholders located after the
    statement bindings.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.

    Raises:
        :class:`mysqlx.NotSupportedError`: If prepared statements are not
                                           supported.
        ValueError: If the statement has a limit and the message is not a
                    find, update or delete message.

    .. versionadded:: 8.0.16
    """
    if stmt.has_limit and msg.type != "Mysqlx.Crud.Insert":
        # Remove 'limit' from message by building a new one
        rebuilders = {
            "Mysqlx.Crud.Find": self.build_find,
            "Mysqlx.Crud.Update": self.build_update,
            "Mysqlx.Crud.Delete": self.build_delete,
        }
        rebuild = rebuilders.get(msg.type)
        if rebuild is None:
            raise ValueError(f"Invalid message type: {msg_type}")
        _, msg = rebuild(stmt)  # type: ignore[arg-type]

        # Build 'limit_expr' message
        first_free = len(stmt.get_bindings())
        placeholder = mysqlxpb_enum("Mysqlx.Expr.Expr.Type.PLACEHOLDER")
        limit_expr = Message("Mysqlx.Crud.LimitExpr")
        limit_expr["row_count"] = Message(
            "Mysqlx.Expr.Expr", type=placeholder, position=first_free
        )
        # Only find messages get an offset placeholder
        if msg.type == "Mysqlx.Crud.Find":
            limit_expr["offset"] = Message(
                "Mysqlx.Expr.Expr", type=placeholder, position=first_free + 1
            )
        msg["limit_expr"] = limit_expr

    oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
    wrapped = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
    wrapped["type"] = mysqlxpb_enum(oneof_type)
    wrapped[oneof_op] = msg

    prepare = Message("Mysqlx.Prepare.Prepare")
    prepare["stmt_id"] = stmt.stmt_id
    prepare["stmt"] = wrapped

    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_PREPARE"),
        prepare,
    )

    try:
        self.read_ok()
    except InterfaceError as err:
        raise NotSupportedError from err
def send_prepare_execute(
    self, msg_type: str, msg: MessageType, stmt: FilterableStatement
) -> None:
    """
    Send the message that executes a prepared statement.

    The statement bindings are sent as arguments; when the statement has a
    limit, the row count and the offset are appended after them.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.

    .. versionadded:: 8.0.16
    """
    # NOTE(review): this one-of message is built but not used by the rest
    # of the method; confirm whether it can be removed.
    oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
    msg_oneof = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
    msg_oneof["type"] = mysqlxpb_enum(oneof_type)
    msg_oneof[oneof_op] = msg

    execute = Message("Mysqlx.Prepare.Execute")
    execute["stmt_id"] = stmt.stmt_id

    bound_args = self._get_binding_args(stmt, is_scalar=False)
    if bound_args:
        execute["args"].extend(bound_args)

    if stmt.has_limit:
        row_count = self._create_any(stmt.get_limit_row_count()).get_message()
        offset = self._create_any(stmt.get_limit_offset()).get_message()
        execute["args"].extend([row_count, offset])

    msg_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_EXECUTE")
    self._writer.write_message(msg_id, execute)
def send_prepare_deallocate(self, stmt_id: int) -> None:
    """
    Send prepare deallocate statement and wait for the server reply.

    Args:
        stmt_id (int): Statement ID.

    Raises:
        :class:`mysqlx.InterfaceError`: If the server does not reply with OK.

    .. versionadded:: 8.0.16
    """
    msg_dealloc = Message("Mysqlx.Prepare.Deallocate")
    msg_dealloc["stmt_id"] = stmt_id
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_DEALLOCATE"),
        msg_dealloc,
    )
    self.read_ok()
def send_msg_without_ps(
    self,
    msg_type: str,
    msg: MessageType,
    stmt: Union[FilterableStatement, SqlStatement],
) -> None:
    """
    Send a message without prepared statements support.

    The limit and the bound values of the statement are added to the
    message before it is sent.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        stmt (Statement): A `Statement` based type object.

    .. versionadded:: 8.0.16
    """
    if stmt.has_limit:
        limit = Message("Mysqlx.Crud.Limit")
        limit["row_count"] = stmt.get_limit_row_count()  # type: ignore[union-attr]
        # The offset is only set for find messages
        if msg.type == "Mysqlx.Crud.Find":
            limit["offset"] = stmt.get_limit_offset()  # type: ignore[union-attr]
        msg["limit"] = limit

    # SQL statements take ``Any`` arguments, the other messages scalars
    bound_args = self._get_binding_args(
        stmt,
        is_scalar=msg_type != "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE",
    )
    if bound_args:
        msg["args"].extend(bound_args)

    self.send_msg(msg_type, msg)
def send_msg(self, msg_type: str, msg: MessageType) -> None:
    """
    Send a message.

    Args:
        msg_type (str): Message ID string.
        msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    # The message ID string is passed through mysqlxpb_enum() and the
    # result is handed to the writer as the message type.
    self._writer.write_message(mysqlxpb_enum(msg_type), msg)
def build_find(
    self, stmt: Union[FindStatement, ReadStatement]
) -> Tuple[str, MessageType]:
    """Build the message for a find/read statement.

    The message carries the projection, the filtering clauses and the
    locking settings of the statement.

    Args:
        stmt (Statement): A :class:`mysqlx.ReadStatement` or
                          :class:`mysqlx.FindStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")

    collection = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    msg = Message("Mysqlx.Crud.Find", data_model=data_model, collection=collection)

    if stmt.has_projection:
        msg["projection"] = stmt.get_projection_expr()
    self._apply_filter(msg, stmt)

    # An exclusive lock takes precedence over a shared one
    lock_name = None
    if stmt.is_lock_exclusive():
        lock_name = "Mysqlx.Crud.Find.RowLock.EXCLUSIVE_LOCK"
    elif stmt.is_lock_shared():
        lock_name = "Mysqlx.Crud.Find.RowLock.SHARED_LOCK"
    if lock_name is not None:
        msg["locking"] = mysqlxpb_enum(lock_name)

    contention = stmt.lock_contention.value
    if contention > 0:
        msg["locking_options"] = contention

    return "Mysqlx.ClientMessages.Type.CRUD_FIND", msg
def build_update(
    self, stmt: Union[ModifyStatement, UpdateStatement]
) -> Tuple[str, MessageType]:
    """Build the message for a modify/update statement.

    The message carries the filtering clauses and one operation per update
    operation of the statement.

    Args:
        stmt (Statement): A :class:`mysqlx.ModifyStatement` or
                          :class:`mysqlx.UpdateStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")

    collection = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    msg = Message(
        "Mysqlx.Crud.Update", data_model=data_model, collection=collection
    )
    self._apply_filter(msg, stmt)

    for update_op in stmt.get_update_ops().values():
        operation = Message("Mysqlx.Crud.UpdateOperation")
        operation["operation"] = update_op.update_type
        operation["source"] = update_op.source
        # Operations without a value leave the field unset
        if update_op.value is not None:
            operation["value"] = build_expr(update_op.value)
        msg["operation"].extend([operation.get_message()])

    return "Mysqlx.ClientMessages.Type.CRUD_UPDATE", msg
def build_delete(
    self, stmt: Union[DeleteStatement, RemoveStatement]
) -> Tuple[str, MessageType]:
    """Build the message for a delete/remove statement.

    Args:
        stmt (Statement): A :class:`mysqlx.DeleteStatement` or
                          :class:`mysqlx.RemoveStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")

    target = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    delete_msg = Message(
        "Mysqlx.Crud.Delete", data_model=data_model, collection=target
    )
    self._apply_filter(delete_msg, stmt)
    return "Mysqlx.ClientMessages.Type.CRUD_DELETE", delete_msg
def build_execute_statement(
    self,
    namespace: str,
    stmt: Union[str, StatementType],
    fields: Optional[Dict[str, Any]] = None,
) -> Tuple[str, MessageType]:
    """Build the message that executes a statement.

    When ``fields`` is given and not empty, it is encoded as an object and
    set as the only argument of the message.

    Args:
        namespace (str): The namespace.
        stmt (Statement): A `Statement` based type object.
        fields (Optional[dict]): The message fields.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    msg = Message(
        "Mysqlx.Sql.StmtExecute",
        namespace=namespace,
        stmt=stmt,
        compact_metadata=False,
    )

    if fields:
        object_fields = [
            Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=name,
                value=self._create_any(field_value),
            ).get_message()
            for name, field_value in fields.items()
        ]
        msg_obj = Message("Mysqlx.Datatypes.Object", fld=object_fields)
        msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
        msg["args"] = [msg_any.get_message()]

    return "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE", msg
@staticmethod
def build_insert(
    stmt: Union[AddStatement, InsertStatement]
) -> Tuple[str, MessageType]:
    """Build the message for an add/insert statement.

    The statement fields (if it has any) become the projection, and every
    value of the statement becomes one row. The upsert flag is set when the
    statement provides it.

    Args:
        stmt (Statement): A :class:`mysqlx.AddStatement` or
                          :class:`mysqlx.InsertStatement` object.

    Returns:
        (tuple): Tuple containing:

            * `str`: Message ID string.
            * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

    .. versionadded:: 8.0.16
    """
    if stmt.is_doc_based():
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.DOCUMENT")
    else:
        data_model = mysqlxpb_enum("Mysqlx.Crud.DataModel.TABLE")

    collection = Message(
        "Mysqlx.Crud.Collection",
        name=stmt.target.name,
        schema=stmt.schema.name,
    )
    msg = Message(
        "Mysqlx.Crud.Insert", data_model=data_model, collection=collection
    )

    # Only some statement types have fields
    if hasattr(stmt, "_fields"):
        for field in stmt._fields:
            parser = ExprParser(field, not stmt.is_doc_based())
            column = parser.parse_table_insert_field()
            msg["projection"].extend([column.get_message()])

    for value in stmt.get_values():
        row = Message("Mysqlx.Crud.Insert.TypedRow")
        # A list holds one value per field; anything else is a single value
        cells = value if isinstance(value, list) else [value]
        for cell in cells:
            row["field"].extend([build_expr(cell).get_message()])
        msg["row"].extend([row.get_message()])

    if hasattr(stmt, "is_upsert"):
        msg["upsert"] = stmt.is_upsert()

    return "Mysqlx.ClientMessages.Type.CRUD_INSERT", msg
def close_result(self, result: ResultBaseType) -> None:
    """Close the result.

    Args:
        result (Result): A `Result` based type object.

    Raises:
        :class:`mysqlx.OperationalError`: If a message is read instead of
                                          the end of the result (i.e. the
                                          message read is not None).
    """
    msg = self._read_message(result)
    # _read_message() returns None once the statement execution is done;
    # any message returned here means the result could not be closed.
    if msg is not None:
        raise OperationalError("Expected to close the result")
def read_row(self, result: ResultBaseType) -> Optional[MessageType]:
    """Read the next row of a result.

    A message that is not a row is pushed back to the reader so that it
    can be read again later.

    Args:
        result (Result): A `Result` based type object.

    Returns:
        mysqlx.protobuf.Message: The row message, or ``None`` if the next
        message is not a row.
    """
    msg = self._read_message(result)
    if msg is None or msg.type == "Mysqlx.Resultset.Row":
        return msg
    self._reader.push_message(msg)
    return None
def get_column_metadata(self, result: ResultBaseType) -> List[ColumnType]:
    """Read the column metadata of a result.

    Reading stops when no message is returned or when the first row is
    found; the row is pushed back to the reader.

    Args:
        result (Result): A `Result` based type object.

    Raises:
        :class:`mysqlx.InterfaceError`: If unexpected message.

    Returns:
        list: One `Column` per column metadata message read.
    """
    columns = []
    msg = self._read_message(result)
    while msg is not None:
        kind = msg.type
        if kind == "Mysqlx.Resultset.Row":
            # Rows belong to the caller: leave it for the next read
            self._reader.push_message(msg)
            break
        if kind != "Mysqlx.Resultset.ColumnMetaData":
            raise InterfaceError("Unexpected msg type")
        columns.append(
            Column(
                msg["type"],
                msg["catalog"],
                msg["schema"],
                msg["table"],
                msg["original_table"],
                msg["name"],
                msg["original_name"],
                msg.get("length", 21),
                msg.get("collation", 0),
                msg.get("fractional_digits", 0),
                msg.get("flags", 16),
                msg.get("content_type"),
            )
        )
        msg = self._read_message(result)
    return columns
def read_ok(self) -> None:
    """Read one message and check that it is an OK message.

    Raises:
        :class:`mysqlx.InterfaceError`: If the message is an error message
                                        (the error code is set as errno) or
                                        any other unexpected message.
    """
    reply = self._reader.read_message()
    kind = reply.type
    if kind == "Mysqlx.Error":
        raise InterfaceError(f"Mysqlx.Error: {reply['msg']}", errno=reply["code"])
    if kind == "Mysqlx.Ok":
        return
    raise InterfaceError("Unexpected message encountered")
def send_connection_close(self) -> None:
    """Send connection close.

    The message is only written; no reply is read here.
    """
    msg = Message("Mysqlx.Connection.Close")
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CLOSE"), msg
    )
def send_close(self) -> None:
    """Send session close.

    The message is only written; no reply is read here.
    """
    msg = Message("Mysqlx.Session.Close")
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_CLOSE"), msg
    )
def send_expect_open(self) -> None:
    """Send an expectation with a single "field exists" condition.

    The message is only written; no reply is read here.
    """
    cond_key = mysqlxpb_enum("Mysqlx.Expect.Open.Condition.Key.EXPECT_FIELD_EXIST")
    msg_oc = Message("Mysqlx.Expect.Open.Condition")
    msg_oc["condition_key"] = cond_key
    # NOTE(review): "6.1" presumably identifies the field whose existence is
    # checked (the `keep_open` field of the session reset message, judging
    # by how send_reset() uses this expectation) - TODO confirm.
    msg_oc["condition_value"] = "6.1"
    msg_eo = Message("Mysqlx.Expect.Open")
    msg_eo["cond"] = [msg_oc.get_message()]
    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.EXPECT_OPEN"), msg_eo
    )
def send_reset(self, keep_open: Optional[bool] = None) -> bool:
    """Send the session reset message and wait for the reply.

    When ``keep_open`` is ``None``, an expectation is sent first to find
    out whether the server supports keeping the session open.

    Args:
        keep_open (Optional[bool]): Whether to ask the server to keep the
                                    session open, or ``None`` to detect it.

    Raises:
        :class:`mysqlx.InterfaceError`: If the server does not reply with OK
                                        to the reset message.

    Returns:
        boolean: ``True`` if the server will keep the session open,
        otherwise ``False``.
    """
    reset = Message("Mysqlx.Session.Reset")

    if keep_open is None:
        try:
            # Send expectation: keep connection open
            self.send_expect_open()
            self.read_ok()
        except InterfaceError:
            # Expectation is unknown by this version of the server
            keep_open = False
        else:
            keep_open = True

    if keep_open:
        reset["keep_open"] = True

    self._writer.write_message(
        mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_RESET"), reset
    )
    self.read_ok()
    return bool(keep_open)