"""
Salt Factories Engine For Salt.

Simple salt engine which connects to a TCP server at the configured ``returner_address``
and forwards every event seen on the daemon's event bus to it, allowing us to know when
a daemon is up and running.
"""
import asyncio
import atexit
import datetime
import logging
import threading
import time
from collections import deque
from collections.abc import MutableMapping
import salt.utils.event
from salt.utils import immutabletypes
try:
from salt.utils.data import CaseInsensitiveDict
except ImportError: # pragma: no cover
CaseInsensitiveDict = None
try:
import msgpack
HAS_MSGPACK = True
except ImportError: # pragma: no cover
HAS_MSGPACK = False
# Name exposed to the salt loader for this module.
__virtualname__ = "pytest"
# Module level logger shared by every function and class below.
log = logging.getLogger(__name__)
def __virtual__():
    """
    Decide whether this engine can be loaded.

    Loading requires ``msgpack`` to be importable, a ``__role`` key in the daemon options,
    a ``pytest-<role>`` section in those options, and a ``returner_address`` key inside
    that section.

    :return: ``True`` when every requirement is met, otherwise a ``(False, reason)`` tuple.
    """
    if HAS_MSGPACK is False:  # pragma: no cover
        return False, "msgpack was not importable. Please install msgpack."
    daemon_opts = __opts__  # pylint: disable=undefined-variable
    if "__role" not in daemon_opts:  # pragma: no cover
        return False, "The required '__role' key could not be found in the options dictionary"
    section_key = "pytest-{}".format(daemon_opts["__role"])
    if section_key not in daemon_opts:  # pragma: no cover
        return False, "No '{}' key in opts dictionary".format(section_key)
    if "returner_address" in daemon_opts[section_key]:
        return True
    return False, "No 'returner_address' key in opts['{}'] dictionary".format(section_key)
def start():
    """
    Start the engine.

    Builds a :class:`PyTestEventForwardEngine` from the daemon options and starts it.
    Any exception raised while constructing or starting the engine is logged and re-raised.
    """
    engine_opts = __opts__  # pylint: disable=undefined-variable
    try:
        engine = PyTestEventForwardEngine(opts=engine_opts)
        engine.start()
    except Exception:  # pragma: no cover pylint: disable=broad-except
        log.exception("Failed to start PyTestEventForwardEngine")
        raise
def ext_type_encoder(obj):
    """
    Convert any types that msgpack cannot handle on its own.

    Meant to be passed as the ``default`` callback of ``msgpack.packb``.

    :param obj: The object msgpack could not serialize natively.
    :return: A msgpack friendly representation of ``obj``, or ``obj`` itself when no
             conversion is known, leaving it to msgpack to raise its own error.
    """
    if isinstance(obj, (datetime.datetime, datetime.date)):
        # msgpack doesn't support datetime.datetime and datetime.date datatypes.
        return obj.strftime("%Y%m%dT%H:%M:%S.%f")
    if isinstance(obj, (set, frozenset)):
        # msgpack can't handle sets, mutable or frozen, so translate them to tuples
        return tuple(obj)
    # The same for salt's immutable types
    if immutabletypes is not None:
        if isinstance(obj, immutabletypes.ImmutableDict):
            return dict(obj)
        if isinstance(obj, immutabletypes.ImmutableList):
            return list(obj)
        if isinstance(obj, immutabletypes.ImmutableSet):
            # msgpack can't handle set so translate it to tuple
            return tuple(obj)
    if CaseInsensitiveDict is not None and isinstance(obj, CaseInsensitiveDict):
        return dict(obj)
    if isinstance(obj, MutableMapping):
        return dict(obj)
    # No known conversion. Let msgpack raise its own error.
    return obj
class PyTestEventForwardClient(asyncio.Protocol):
    """
    TCP Client to forward events.

    Once connected, a task pops payloads from ``queue``, serializes them with msgpack and
    writes them to the server. Forwarding stops, and the client is flagged as disconnected,
    in any of these cases:

    - a ``None`` payload is popped;
    - ``client_running_event`` is cleared;
    - the connection is lost;
    - forwarding fails more than 10 times.
    """

    def __init__(self, queue, client_running_event):
        """
        :param queue: The deque to pop payloads from. A ``None`` payload acts as a stop sentinel.
        :param client_running_event: A ``threading.Event`` which is set once the queue is
                                     being processed; clearing it stops the processing.
        """
        self.queue = queue
        self.running = client_running_event
        self.task = None
        self.transport = None
        try:
            loop = asyncio.get_running_loop()
        except AttributeError:
            # Python < 3.7
            loop = asyncio.get_event_loop()
        # Kept so the queue processing task is created on the same loop as the futures below
        self._loop = loop
        self._connected = loop.create_future()
        self._disconnected = loop.create_future()

    def connection_made(self, transport):
        """
        Connection established.

        Stores the transport and starts the task which processes the queue.
        """
        peername = transport.get_extra_info("peername")
        log.debug("%s: Connected to %s", self.__class__.__name__, peername)
        if not self._connected.done():
            self._connected.set_result(True)
        self.transport = transport
        self.task = self._loop.create_task(self._process_queue())

    def connection_lost(self, exc):  # noqa: ARG002
        """
        Connection lost.

        Flags the client as disconnected and cancels the queue processing task.
        """
        log.debug("%s: The server closed the connection", self.__class__.__name__)
        self._set_disconnected()
        if self.task is not None:
            self.task.cancel()

    async def wait_connected(self):
        """
        Wait until a connection to the server is successful.
        """
        return await self._connected

    async def wait_disconnected(self):
        """
        Wait until disconnected from the server.
        """
        return await self._disconnected

    def _set_disconnected(self):
        """
        Resolve the disconnected future unless it is already resolved.

        Both :meth:`connection_lost` and :meth:`_process_queue` signal disconnection, and
        calling ``set_result`` on an already resolved future raises ``InvalidStateError``.
        """
        if not self._disconnected.done():
            self._disconnected.set_result(True)

    async def _process_queue(self):
        """
        Forward payloads from the queue to the server until told to stop.
        """
        self.running.set()
        log.info("%s: Now processing the queue", self.__class__.__name__)
        restarts = 0
        max_restarts = 10
        while restarts <= max_restarts and self.running.is_set():
            try:
                try:
                    payload = self.queue.popleft()
                except IndexError:
                    await asyncio.sleep(1)
                    continue
                if payload is None:
                    # Stop sentinel. Break (rather than return) so waiters get notified below.
                    break
                dumped = msgpack.packb(payload, use_bin_type=True, default=ext_type_encoder)
                self.transport.write(dumped)
                log.debug("%s: forwarded event: %r", self.__class__.__name__, payload)
                # Yield to the event loop so buffered writes get flushed and callbacks
                # such as connection_lost can run even while the queue stays non-empty.
                await asyncio.sleep(0)
            except asyncio.CancelledError:
                break
            except Exception:  # pylint: disable=broad-except
                log.exception(
                    "%s: Caught exception while pulling data from queue",
                    self.__class__.__name__,
                )
                restarts += 1
        self._set_disconnected()
class PyTestEventForwardEngine:
    """
    Salt Engine instance.

    Reads every event from the daemon's event bus and appends it to a bounded queue. A
    :class:`PyTestEventForwardClient` forwards the queued events to ``returner_address``.
    The client runs on its own asyncio event loop in a dedicated thread.
    """

    __slots__ = (
        "opts",
        "id",
        "role",
        "returner_address_host",
        "returner_address_port",
        "running_event",
        "client_running_event",
        "loop",
        "client",
        "queue",
        "running_thread",
    )

    def __init__(self, opts):
        """
        :param opts: The daemon options. ``id``, ``__role`` and
                     ``pytest-<role>["returner_address"]["host"|"port"]`` are read here.
        """
        self.opts = opts
        self.id = self.opts["id"]  # pylint: disable=invalid-name
        self.role = self.opts["__role"]
        returner_address = self.opts["pytest-{}".format(self.role)]["returner_address"]
        self.returner_address_host = returner_address["host"]
        self.returner_address_port = returner_address["port"]
        self.running_event = threading.Event()
        self.client_running_event = threading.Event()
        # Created here, but only ever run (and closed) by ``running_thread``
        self.loop = asyncio.new_event_loop()
        self.client = None
        # Bounded: once full, appending drops the oldest queued event
        self.queue = deque(maxlen=1000)
        self.running_thread = threading.Thread(target=self._run_loop_in_thread, args=(self.loop,))

    def _run_loop_in_thread(self, loop):
        """
        Thread target: run the forwarding client on ``loop``, then close ``loop``.

        The loop handed in by ``__init__`` is used, instead of creating (and leaking)
        a second one.
        """
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._run_client(loop))
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    async def _run_client(self, loop):
        """
        Connect the forwarding client to ``returner_address`` and wait until it disconnects.

        The transport is closed on every exit path. This includes the case where the
        connection is not confirmed within 15 seconds.
        """
        log.debug(
            "%s client connecting to %s:%s",
            self.__class__.__name__,
            self.returner_address_host,
            self.returner_address_port,
        )
        self.client = PyTestEventForwardClient(self.queue, self.client_running_event)
        transport, _ = await loop.create_connection(
            lambda: self.client,
            self.returner_address_host,
            self.returner_address_port,
        )
        try:
            await asyncio.wait_for(self.client.wait_connected(), timeout=15)
        except asyncio.TimeoutError:
            log.error("The client failed to connect to the server after 15 seconds")  # noqa: TRY400
        else:
            log.info("%s client started", self.__class__.__name__)
            # Wait until the protocol signals that the connection is lost
            await self.client.wait_disconnected()
        finally:
            transport.close()

    def __repr__(self):  # noqa: D105
        return "<{} role={!r} id={!r}, returner_address='{}:{}' running={!r}>".format(
            self.__class__.__name__,
            self.role,
            self.id,
            self.returner_address_host,
            self.returner_address_port,
            self.running_event.is_set(),
        )

    def start(self):
        """
        Start the engine.

        Starts the client thread and waits up to 10 seconds for the client to process its
        queue. Then forwards every event seen on the event bus until :meth:`stop` is called.
        Calling this while the engine is already running does nothing.

        :raises RuntimeError: If the client does not start in time, or its thread dies first.
        """
        if self.running_event.is_set():
            return
        log.info("%s is starting", self)
        atexit.register(self.stop)
        self.running_event.set()
        self.running_thread.start()
        self._wait_for_client(timeout=10)
        try:
            self._forward_events()
        finally:
            if self.running_event.is_set():
                # Some exception happened, unset
                self.running_event.clear()

    def _wait_for_client(self, timeout):
        """
        Block until the client reports it is processing its queue.

        :param timeout: Maximum number of seconds to wait.
        :raises RuntimeError: On timeout, or when the client thread is no longer alive.
        """
        # Monotonic clock: immune to wall-clock adjustments while waiting
        timeout_at = time.monotonic() + timeout
        while True:
            log.info("Waiting for %s.client to start...", self.__class__.__name__)
            if self.client is not None and self.client_running_event.is_set():
                return
            if time.monotonic() > timeout_at or not self.running_thread.is_alive():
                msg = "Failed to start client"
                raise RuntimeError(msg)
            time.sleep(1)

    def _forward_events(self):
        """
        Queue every event read from the daemon's event bus until the engine is stopped.

        On a master, a ``salt/master/<id>/start`` event is fired first.
        """
        opts = self.opts.copy()
        opts["file_client"] = "local"
        with salt.utils.event.get_event(
            self.role,
            sock_dir=opts["sock_dir"],
            opts=opts,
            listen=True,
        ) as eventbus:
            if self.role == "master":
                event_tag = "salt/master/{}/start".format(self.id)
                log.info("%s firing event on engine start. Tag: %s", self, event_tag)
                load = {"id": self.id, "tag": event_tag, "data": {}}
                eventbus.fire_event(load, event_tag)
            log.info("%s started", self)
            while self.running_event.is_set():
                # Poll with get_event() rather than looping over iter_events(). The
                # running flag is then re-checked after every (possibly empty) read, so
                # stop() is honoured even when no events arrive.
                event = eventbus.get_event(full=True, auto_reconnect=True)
                if not event:
                    continue
                tag = event["tag"]
                data = event["data"]
                log.debug("%s Received Event; TAG: %r DATA: %r", self, tag, data)
                forward = {"id": self.id, "tag": tag, "data": data}
                self.queue.append(forward)

    def stop(self):
        """
        Stop the engine.

        Signals the client to stop, then waits for its thread to finish. Does nothing when
        the engine is not running.
        """
        if not self.running_event.is_set():
            return
        log.info("Stopping %s", self)
        self.running_event.clear()
        # ``None`` is the client's stop sentinel
        self.queue.append(None)
        self.client_running_event.clear()
        # join() raises on a thread that was never started, and is pointless on a dead one
        if self.running_thread.is_alive():
            self.running_thread.join()
        log.info("%s stopped", self)