"""
Z21 Station controller.
Main class for communicating with a Z21 DCC command station over UDP.
Provides async methods for station control and locomotive operations.
"""
from __future__ import annotations
import asyncio
import logging
import struct
from collections.abc import Callable
from typing import Any
from z21aio.headers import get_header_name
from .packet import Packet
from .messages import (
LAN_DISCOVER_DEVICES,
LAN_GET_SERIAL_NUMBER,
LAN_LOGOFF,
LAN_XBUS_HEADER,
LAN_SET_BROADCASTFLAGS,
LAN_SYSTEMSTATE_DATACHANGED,
LAN_SYSTEMSTATE_GETDATA,
LAN_RAILCOM_DATACHANGED,
LAN_RAILCOM_GETDATA,
XBUS_GET_VERSION_REPLY,
XBUS_GET_FIRMWARE_VERSION_REPLY,
XBUS_BC_TRACK_POWER,
XBUS_BC_TRACK_POWER_ON_DB0,
BROADCAST_LOCO_INFO,
BROADCAST_RAILCOM_SUBSCRIBED,
BROADCAST_RAILCOM_ALL,
XBusMessage,
)
from .types import SystemState, RailComData, LocoState
# Module-wide logger, named after this module.
log = logging.getLogger(__name__)
# Default UDP port used by Z21Station.connect() when the caller passes none.
DEFAULT_PORT = 21105
# Default time, in seconds, that receive_packet() waits for a reply.
DEFAULT_TIMEOUT = 2.0
# Seconds the keep-alive loop sleeps between re-sending the broadcast flags.
KEEP_ALIVE_INTERVAL = 20.0
# NOTE(review): not referenced in the visible code - presumably a receive
# buffer size in bytes; confirm it is still needed.
BUFFER_SIZE = 1024
# Capacity of each per-header queue that buffers packets for receive_packet().
QUEUE_MAX_SIZE = 10
class Z21Protocol(asyncio.DatagramProtocol):
    """UDP protocol handler for Z21 communication.

    Splits every received datagram into Z21 packets and hands each one to
    the owning station's ``_handle_packet``. Also records connection loss
    on the station.
    """

    def __init__(self, station: Z21Station) -> None:
        """Bind the protocol to the station that will receive its packets.

        Args:
            station: Station whose ``_handle_packet`` is called per packet.
        """
        log.debug("Initializing Z21Protocol")
        self._station = station
        self._transport: asyncio.DatagramTransport | None = None

    def connection_made(self, transport: asyncio.DatagramTransport) -> None:  # type: ignore[override]
        """Remember the transport once the UDP endpoint is ready."""
        log.debug("UDP connection established")
        self._transport = transport

    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Parse all packets contained in one datagram and dispatch them.

        Parsing stops at the first packet that cannot be decoded, because the
        position of the next packet is then unknown. A failure while
        *handling* a correctly framed packet is logged and the remaining
        packets of the datagram are still processed.

        Args:
            data: Raw datagram payload.
            addr: Sender address, used for logging only.
        """
        offset = 0
        total = len(data)
        while offset < total:
            try:
                packet = Packet.from_bytes(data[offset:])
            except (ValueError, struct.error) as e:
                log.error(
                    "Failed to parse packet from %s at offset %d: %s <%s>",
                    addr,
                    offset,
                    e,
                    data[offset:].hex(" "),
                )
                break
            log.debug(
                "datagram_received: %d bytes from %s %s", len(data), addr, packet
            )
            try:
                self._station._handle_packet(packet)
            except (ValueError, struct.error) as e:
                # The packet was framed correctly, so the next one can still
                # be located; only this packet is lost.
                log.error(
                    "Failed to handle packet from %s at offset %d: %s",
                    addr,
                    offset,
                    e,
                )
            step = packet.data_len
            if step <= 0:
                # Without forward progress the loop would never terminate.
                log.error(
                    "Packet from %s at offset %d reports length %d; "
                    "dropping rest of datagram",
                    addr,
                    offset,
                    step,
                )
                break
            offset += step

    def error_received(self, exc: Exception) -> None:
        """Log an error reported by the transport."""
        log.error("Protocol error: %s", exc)

    def connection_lost(self, exc: Exception | None) -> None:
        """Log the closure and flag the station as having lost its connection."""
        if exc:
            log.error("Connection lost with error: %s", exc)
        else:
            log.debug("Connection closed")
        self._station._connection_lost = True
[docs]
class Z21Station:
"""
Z21 DCC command station controller.
Provides async methods for communicating with a Z21 station over UDP.
Supports multiple simultaneous connections to different stations.
Example:
async with await Z21Station.connect("192.168.0.111") as station:
await station.voltage_on()
serial = await station.get_serial_number()
print(f"Serial: {serial}")
"""
def __init__(self) -> None:
    """Create an unconnected station holding default settings.

    No socket is opened here; ``connect()`` fills in the endpoint fields.
    """
    # Endpoint configuration
    self._host: str = ""
    self._port: int = DEFAULT_PORT
    self._timeout: float = DEFAULT_TIMEOUT
    # UDP endpoint objects, set once connected
    self._transport: asyncio.DatagramTransport | None = None
    self._protocol: Z21Protocol | None = None
    # Lifecycle state
    self._running: bool = False
    self._connection_lost: bool = False
    self._keep_alive_task: asyncio.Task[None] | None = None
    # Packet routing: per-header reply queues and per-header callbacks
    self._packet_waiters: dict[int, asyncio.Queue[Packet]] = {}
    self._subscribers: dict[int, list[Callable[[Packet], None]]] = {}
    # Flags re-sent by the keep-alive loop
    self._broadcast_flags: int = BROADCAST_LOCO_INFO
[docs]
@classmethod
async def connect(
    cls,
    host: str,
    port: int = DEFAULT_PORT,
    timeout: float = DEFAULT_TIMEOUT,
    *,
    keep_alive: bool = True,
) -> Z21Station:
    """
    Connect to a Z21 station.

    Opens the UDP endpoint, sends the initial broadcast flags and, if
    requested, starts the keep-alive task. If the initial send fails the
    endpoint is closed again before the error is re-raised, so no socket
    or background task is left behind.

    Args:
        host: IP address of the Z21 station
        port: UDP port (default 21105)
        timeout: Command timeout in seconds (default 2.0)
        keep_alive: Start the 20-second keep-alive loop automatically
            (default True).

    Returns:
        Connected Z21Station instance
    """
    station = cls()
    station._host = host
    station._port = port
    station._timeout = timeout

    loop = asyncio.get_running_loop()
    transport, protocol = await loop.create_datagram_endpoint(
        lambda: Z21Protocol(station),
        remote_addr=(host, port),
    )
    station._transport = transport
    station._protocol = protocol
    station._running = True

    try:
        await station._set_broadcast_flags(station._broadcast_flags)
    except BaseException:
        # Undo the half-finished connection; the caller never receives the
        # station object and therefore could not close it.
        station._running = False
        station._transport = None
        transport.close()
        raise

    # Started only after the first send succeeded. The loop sleeps before
    # its first transmission, so this ordering is not observable on success.
    if keep_alive:
        station._keep_alive_task = asyncio.create_task(station._keep_alive_loop())
    return station
async def _keep_alive_loop(self) -> None:
    """Re-send the current broadcast flags every KEEP_ALIVE_INTERVAL seconds.

    Runs until ``_running`` is cleared or the task is cancelled. Nothing is
    sent while the connection is flagged as lost.
    """
    while self._running:
        try:
            await asyncio.sleep(KEEP_ALIVE_INTERVAL)
            if not self._running or self._connection_lost:
                continue
            await self._set_broadcast_flags(self._broadcast_flags)
        except asyncio.CancelledError:
            # Cancellation ends the loop quietly.
            return
        except (OSError, ConnectionError) as e:
            log.debug("Keep-alive failed: %s", e)
[docs]
def start_keep_alive(self) -> None:
    """Launch the keep-alive task unless one is already active.

    A task that has already finished is replaced by a new one.
    """
    current = self._keep_alive_task
    if current is not None and not current.done():
        return
    self._keep_alive_task = asyncio.create_task(self._keep_alive_loop())
[docs]
async def stop_keep_alive(self) -> None:
    """Cancel the keep-alive task, wait for it to end, and forget it.

    Does nothing when no task has been started.
    """
    task = self._keep_alive_task
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        pass
    self._keep_alive_task = None
async def _set_broadcast_flags(self, flags: int) -> None:
    """Send LAN_SET_BROADCASTFLAGS carrying *flags* as a little-endian uint32.

    Args:
        flags: Bit mask selecting which broadcasts the station should send.
    """
    payload = struct.pack("<I", flags)
    await self.send_packet(
        Packet.with_header_and_data(LAN_SET_BROADCASTFLAGS, payload)
    )
def _handle_packet(self, packet: Packet) -> None:
    """Route an incoming packet to waiters and subscribers.

    Packets are keyed by their LAN header, except X-BUS packets, which are
    keyed by the X-header of the contained message.

    Args:
        packet: Packet decoded from a received datagram.
    """
    key = packet.header
    if key == LAN_XBUS_HEADER:
        key = XBusMessage.from_bytes(packet.data).x_header

    queue = self._packet_waiters.get(key)
    if queue is not None:
        try:
            queue.put_nowait(packet)
        except asyncio.QueueFull:
            # Nobody consumed the backlog: discard it and keep only the
            # newest packet. Draining by hand avoids Queue.shutdown(),
            # which only exists on Python 3.13+.
            while not queue.empty():
                queue.get_nowait()
            queue.put_nowait(packet)

    # Iterate over a snapshot so a callback that registers another
    # subscriber for the same key does not mutate the list being walked.
    for callback in tuple(self._subscribers.get(key, ())):
        try:
            callback(packet)
        except (TypeError, ValueError) as e:
            log.error(
                "Got callback error from subscribed header=%s %s",
                get_header_name(key),
                e,
            )
[docs]
async def send_packet(self, packet: Packet) -> None:
    """
    Serialise a packet and hand it to the UDP transport.

    Args:
        packet: Packet to send

    Raises:
        ConnectionError: If no transport is available.
    """
    log.debug("Sending packet: %s", packet)
    transport = self._transport
    if transport is None:
        raise ConnectionError("Not connected to Z21 station")
    transport.sendto(packet.to_bytes())
[docs]
async def receive_packet(
    self,
    header: int,
    timeout: float | None = None,
) -> Packet:
    """
    Wait for the next packet queued under the given header.

    The queue for *header* is created on first use and is filled by the
    packet router.

    Args:
        header: Expected packet header
        timeout: Timeout in seconds (uses the station default if None)

    Returns:
        Received packet

    Raises:
        asyncio.TimeoutError: If no packet arrives within the timeout
    """
    wait_seconds = self._timeout if timeout is None else timeout

    queue = self._packet_waiters.get(header)
    if queue is None:
        queue = asyncio.Queue(maxsize=QUEUE_MAX_SIZE)
        self._packet_waiters[header] = queue

    try:
        return await asyncio.wait_for(queue.get(), timeout=wait_seconds)
    except asyncio.TimeoutError:
        raise asyncio.TimeoutError(
            f"Timeout waiting for packet with header 0x{header:04X}"
        )
[docs]
async def send_xbus_command(
    self,
    msg: XBusMessage,
    expected_response_header: int | None = None,
) -> XBusMessage | None:
    """
    Wrap an XBus message in a LAN X-BUS packet, send it, and optionally await a reply.

    Args:
        msg: XBus message to send
        expected_response_header: XBus header to wait for (None = no wait)

    Returns:
        The decoded reply when expected_response_header is given, else None
    """
    request = Packet.with_header_and_data(LAN_XBUS_HEADER, msg.to_bytes())
    await self.send_packet(request)

    if expected_response_header is None:
        return None

    reply = await self.receive_packet(expected_response_header)
    return XBusMessage.from_bytes(reply.data)
[docs]
async def get_serial_number(self) -> int:
    """
    Request and return the Z21 station serial number.

    Returns:
        Serial number, decoded from the first four reply bytes as a
        little-endian unsigned integer

    Raises:
        ValueError: If the reply carries fewer than four data bytes
    """
    await self.send_packet(Packet.with_header(LAN_GET_SERIAL_NUMBER))
    reply = await self.receive_packet(LAN_GET_SERIAL_NUMBER)

    payload = reply.data
    if len(payload) < 4:
        raise ValueError("Invalid serial number response")

    (serial,) = struct.unpack("<I", payload[:4])
    return serial
[docs]
async def discover_devices(self) -> None:
    """
    Send a LAN_DISCOVER_DEVICES request.

    The request is fire-and-forget: this method neither waits for nor
    parses any reply.

    Returns:
        None
    """
    await self.send_packet(Packet.with_header(LAN_DISCOVER_DEVICES))
[docs]
async def get_firmware_version(self) -> tuple[int, int]:
    """
    Query the Z21 station firmware version.

    Returns:
        Tuple of (major, minor) as the raw BCD-encoded reply bytes; the
        values are not decoded. For example, firmware version 1.30 is
        returned as (0x01, 0x30).

    Raises:
        ValueError: If the reply is missing or has fewer than three data bytes
    """
    reply = await self.send_xbus_command(
        XBusMessage.get_firmware_version(),
        expected_response_header=XBUS_GET_FIRMWARE_VERSION_REPLY,
    )
    if reply is None:
        raise ValueError("Invalid firmware version response")
    if len(reply.dbs) < 3:
        raise ValueError("Invalid firmware version response")

    # Reply layout: DB0=0x0A, DB1=V_MSB (BCD), DB2=V_LSB (BCD)
    major_bcd = reply.dbs[1]
    minor_bcd = reply.dbs[2]
    return (major_bcd, minor_bcd)
[docs]
async def get_version(self) -> tuple[int, int]:
    """
    Query the X-BUS protocol version and the command station ID.

    Returns:
        Tuple of (xbus_version, command_station_id).
        xbus_version is the raw BCD byte (e.g., 0x36 = version 3.6).
        command_station_id identifies the type of command station.

    Raises:
        ValueError: If the reply is missing or has fewer than three data bytes
    """
    reply = await self.send_xbus_command(
        XBusMessage.get_version(),
        expected_response_header=XBUS_GET_VERSION_REPLY,
    )
    if reply is None:
        raise ValueError("Invalid version response")
    if len(reply.dbs) < 3:
        raise ValueError("Invalid version response")

    # Reply layout: DB0=0x21, DB1=X-BUS version, DB2=command station ID
    return (reply.dbs[1], reply.dbs[2])
[docs]
async def voltage_on(self) -> None:
    """Switch track power on; no reply is awaited."""
    await self.send_xbus_command(XBusMessage.track_power_on())
[docs]
async def voltage_off(self) -> None:
    """Switch track power off (emergency stop for all locomotives); no reply is awaited."""
    await self.send_xbus_command(XBusMessage.track_power_off())
[docs]
async def logout(self) -> None:
    """Send the LAN_LOGOFF packet to the Z21; no reply is awaited."""
    await self.send_packet(Packet.with_header(LAN_LOGOFF))
[docs]
def subscribe_system_state(
    self,
    callback: Callable[[SystemState], None],
    freq_hz: float = 1.0,
) -> asyncio.Task[None]:
    """
    Subscribe to system state updates.

    Registers a handler for LAN_SYSTEMSTATE_DATACHANGED packets and starts
    a background task that requests the system state at a fixed rate. The
    handler stays registered after the task ends.

    Args:
        callback: Function called with SystemState on each update
        freq_hz: Polling frequency in Hz (default 1.0); values <= 0 fall
            back to one request per second

    Returns:
        Background task handle (can be cancelled)
    """
    interval = 1.0 / freq_hz if freq_hz > 0 else 1.0

    async def poll_loop() -> None:
        try:
            while self._running:
                try:
                    await self.send_packet(
                        Packet.with_header(LAN_SYSTEMSTATE_GETDATA)
                    )
                except (OSError, ConnectionError) as e:
                    log.debug("System state poll error: %s", e)
                # Sleep on the error path too: send_packet() never suspends,
                # so retrying immediately would spin without yielding to the
                # event loop.
                await asyncio.sleep(interval)
        except asyncio.CancelledError:
            # Cancelling the returned task ends polling quietly.
            return

    def handle_state(packet: Packet) -> None:
        try:
            callback(SystemState.from_bytes(packet.data))
        except (ValueError, TypeError) as e:
            log.debug("Invalid system state packet or callback error: %s", e)

    # Subscribe to state change packets
    self._subscribers.setdefault(LAN_SYSTEMSTATE_DATACHANGED, []).append(
        handle_state
    )
    # Start polling task
    return asyncio.create_task(poll_loop())
[docs]
def subscribe_track_power(
    self,
    callback: Callable[[bool], None],
) -> None:
    """
    Register a callback for track power state broadcasts.

    The callback fires on every track power broadcast, whether the change
    was triggered by this client or an external device (e.g., multiMaus).
    Requires broadcast flag 0x00000001, which is set by default.

    Args:
        callback: Called with True when track power turns on,
            False when it turns off.
    """

    def on_power_packet(packet: Packet) -> None:
        try:
            dbs = XBusMessage.from_bytes(packet.data).dbs
            if len(dbs) >= 1:
                # DB0 tells whether power went on or off.
                callback(dbs[0] == XBUS_BC_TRACK_POWER_ON_DB0)
        except (ValueError, TypeError) as e:
            log.error("Error in track power subscription callback: %s", e)

    self._subscribers.setdefault(XBUS_BC_TRACK_POWER, []).append(on_power_packet)
[docs]
async def enable_railcom_broadcasts(self, all_locos: bool = False) -> None:
    """
    Enable RailCom data broadcasts.

    Args:
        all_locos: If True, receive RailCom data for all locos.
            If False (default), only receive data for subscribed locos;
            a previously enabled "all locos" flag is cleared so the
            narrower mode actually takes effect.

    Note:
        all_locos=True requires firmware 1.29+
    """
    if all_locos:
        self._broadcast_flags |= BROADCAST_RAILCOM_ALL
    else:
        # Without clearing the "all" bit an earlier all_locos=True call
        # would keep every loco's data flowing.
        self._broadcast_flags &= ~BROADCAST_RAILCOM_ALL
        self._broadcast_flags |= BROADCAST_RAILCOM_SUBSCRIBED
    await self._set_broadcast_flags(self._broadcast_flags)
[docs]
async def disable_railcom_broadcasts(self) -> None:
    """Clear both RailCom broadcast flags and send the updated mask."""
    flags = self._broadcast_flags
    flags &= ~BROADCAST_RAILCOM_SUBSCRIBED
    flags &= ~BROADCAST_RAILCOM_ALL
    self._broadcast_flags = flags
    await self._set_broadcast_flags(self._broadcast_flags)
[docs]
async def get_railcom_data(
    self,
    address: int | None = None,
    timeout: float | None = None,
) -> RailComData:
    """
    Request RailCom data for one locomotive or for the next one in the queue.

    Args:
        address: DCC address to query, or None for circular polling
        timeout: Response timeout in seconds (uses default if None)

    Returns:
        RailComData decoded from the next LAN_RAILCOM_DATACHANGED packet

    Raises:
        asyncio.TimeoutError: If no response within timeout

    Note:
        Requires firmware 1.29+
    """
    if address is None:
        # Type 0x00 = poll next in circular queue
        poll_type, target = 0x00, 0x0000
    else:
        # Type 0x01 = poll specific address
        poll_type, target = 0x01, address

    payload = struct.pack("<BH", poll_type, target)
    request = Packet.with_header_and_data(LAN_RAILCOM_GETDATA, payload)
    await self.send_packet(request)

    reply = await self.receive_packet(LAN_RAILCOM_DATACHANGED, timeout)
    return RailComData.from_bytes(reply.data)
[docs]
def subscribe_railcom(
    self,
    callback: Callable[[RailComData], None],
    address: int | None = None,
) -> None:
    """
    Register a callback for RailCom data broadcasts.

    Args:
        callback: Function called with RailComData on each update
        address: If specified, only data for this loco address is passed on.
            If None, every RailCom broadcast is passed on.

    Note:
        Call enable_railcom_broadcasts() first to receive broadcasts.
    """

    def handle_railcom(packet: Packet) -> None:
        try:
            decoded = RailComData.from_bytes(packet.data)
            if address is not None and decoded.loco_address != address:
                return
            callback(decoded)
        except (ValueError, TypeError) as e:
            log.error("Error handling RailCom packet: %s", e)

    self._subscribers.setdefault(LAN_RAILCOM_DATACHANGED, []).append(
        handle_railcom
    )
[docs]
def subscribe_railcom_polled(
    self,
    callback: Callable[[RailComData], None],
    address: int | None = None,
    freq_hz: float = 1.0,
) -> asyncio.Task[None]:
    """
    Subscribe to RailCom data via polling.

    Polls the Z21 at the specified frequency for RailCom data.
    Useful when broadcast flags cannot be changed or for specific addresses.
    A timeout, a transport error, or a malformed reply / failing callback
    (ValueError, TypeError) skips that poll; polling then continues.

    Args:
        callback: Function called with RailComData on each poll
        address: Specific address to poll, or None for circular polling
        freq_hz: Polling frequency in Hz (default 1.0); values <= 0 fall
            back to one poll per second

    Returns:
        Background task handle (can be cancelled)
    """
    interval = 1.0 / freq_hz if freq_hz > 0 else 1.0

    async def poll_loop() -> None:
        try:
            while self._running:
                try:
                    callback(await self.get_railcom_data(address))
                except asyncio.TimeoutError:
                    # No response - may be no RailCom-capable decoder.
                    # Must precede OSError: on Python 3.11+ TimeoutError
                    # is an OSError subclass.
                    pass
                except (OSError, ConnectionError) as e:
                    log.error("RailCom poll error: %s", e)
                except (ValueError, TypeError) as e:
                    # Same policy as the broadcast subscriber: one bad
                    # packet or callback error must not end polling.
                    log.error("Error handling RailCom packet: %s", e)
                await asyncio.sleep(interval)
        except asyncio.CancelledError:
            # Cancelling the returned task ends polling quietly, whichever
            # await was active.
            return

    return asyncio.create_task(poll_loop())
[docs]
def subscribe_loco_state(
    self,
    callback: Callable[[LocoState], None],
) -> None:
    """
    Register a callback for locomotive state broadcasts of any locomotive.

    The callback is invoked whenever the station broadcasts a loco info
    message. The LocoState object includes the locomotive address, speed,
    functions, and other state.

    Args:
        callback: Function called with LocoState for each update.
            Receives updates for ALL locomotives.

    Example:
        def on_any_loco_state(state: LocoState):
            print(f"Loco {state.address}: speed={state.speed_percentage}%")
        station.subscribe_loco_state(on_any_loco_state)
    """
    from .messages import XBUS_LOCO_INFO

    def on_loco_packet(packet: Packet) -> None:
        try:
            message = XBusMessage.from_bytes(packet.data)
            if message.x_header != XBUS_LOCO_INFO:
                return
            callback(LocoState.from_bytes(message.dbs))
        except (ValueError, TypeError) as e:
            log.error("Error in loco state subscription callback: %s", e)

    # Packets are routed by X-header, so register under XBUS_LOCO_INFO
    self._subscribers.setdefault(XBUS_LOCO_INFO, []).append(on_loco_packet)
[docs]
async def close(self) -> None:
    """
    Close the connection and clean up resources.

    Stops the keep-alive task, sends logout while a transport is still
    open, and closes the transport. The transport is closed even if
    stopping the task or logging out raises; such an error still
    propagates to the caller. Calling close() again is harmless.
    """
    self._running = False
    try:
        # Detach the task first so the handle is cleared whatever happens.
        task, self._keep_alive_task = self._keep_alive_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected result of the cancel() above - not an error.
                pass

        # Skip logout when already closed; it could only fail.
        if self._transport is not None:
            try:
                await self.logout()
            except (OSError, ConnectionError) as e:
                log.warning("Error while logging out error=%s", e)
    finally:
        transport, self._transport = self._transport, None
        if transport is not None:
            transport.close()
[docs]
async def __aenter__(self) -> Z21Station:
return self
[docs]
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
await self.close()
def __repr__(self) -> str:
status = "connected" if self._running else "disconnected"
return f"Z21Station({self._host}:{self._port}, {status})"