from __future__ import annotations
import os
import threading
import traceback
import grpc
from enum import Enum, unique
from time import sleep
from .test_runner_protoc import test_runner_pb2 as pb
from .test_runner_protoc import test_runner_pb2_grpc as pbgrpc
from dataclasses import dataclass
from datetime import timedelta
import sys
class TestRunner:
    """Client for the ``TestRunnerService`` gRPC API.

    Each public method wraps one RPC: it builds the request message, calls the
    stub and, where the RPC has a result, returns the relevant response field.
    Methods taking ``timeout_ms`` accept ``None`` (sent as ``-1``) or a positive
    number of milliseconds; anything else raises ``ValueError``.
    """

    # Marker payload ``_open_stream`` waits for before handing a stream back.
    STREAM_OPEN_MESSAGE = "STREAM_OPEN"

    def __init__(self, port):
        """Connect to the test-runner server at ``localhost:<port>`` over an insecure channel."""
        self._channel = grpc.insecure_channel(f'localhost:{port}')
        self._stub = pbgrpc.TestRunnerServiceStub(self._channel)

    def __del__(self):
        """Close the gRPC channel when the client is garbage-collected."""
        # ``__init__`` may have raised before ``_channel`` was assigned; do not
        # raise a secondary AttributeError from the finalizer in that case.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_timeout_ms(timeout_ms):
        """Validate a ``timeout_ms`` argument and return the value to put on the wire.

        ``None`` is sent as ``-1`` (presumably "no timeout" on the server side --
        confirm against the server implementation).

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
        """
        if timeout_ms is None:
            return -1
        if timeout_ms <= 0:
            raise ValueError("Cannot set up the 'timeout' 0 or less.")
        return timeout_ms

    @staticmethod
    def _encode_multi_states(states):
        """Encode per-channel states for ``*Multi`` requests: ``None`` -> ``-1``, otherwise ``int(state)``."""
        return [-1 if state is None else int(state) for state in states]

    @staticmethod
    def _board_request(board_id, channel_number):
        """Build the ``BoardRequest`` addressing ``channel_number`` on board ``board_id``."""
        return pb.BoardRequest(board_id=board_id, channel_number=channel_number)

    def _get_env_value(self, key: str) -> str | None:
        """Return the value of environment variable ``key`` (internal use only).

        Args:
            key: Environment variable name.

        Returns:
            The variable's value, or ``None`` if it is not set.

        Notes:
            This is an internal implementation detail; how the value is obtained
            may change in the future. Do not call it from user code.
        """
        return os.getenv(key)

    # ------------------------------------------------------------------
    # Test lifecycle
    # ------------------------------------------------------------------

    def start_test(self, folder_name: str):
        """Call ``StartTest`` for the test folder ``folder_name``."""
        self._stub.StartTest(pb.StartTestRequest(folder_name=folder_name))

    def stop_test(self):
        """Call ``StopTest``."""
        self._stub.StopTest(pb.Empty())

    # ------------------------------------------------------------------
    # Serial
    # ------------------------------------------------------------------

    def send_serial(self, port_number, serial_data):
        """Send ``serial_data`` on serial port ``port_number``."""
        self._stub.SendSerial(pb.SendSerialRequest(port_number=port_number, serial_data=serial_data))

    def send_serial_command(self, port_number, serial_command):
        """Send ``serial_command`` on serial port ``port_number`` via ``SendSerialCommand``."""
        self._stub.SendSerialCommand(
            pb.SendSerialCommandRequest(port_number=port_number, serial_command=serial_command))

    def get_latest_serial(self, port_number, direction):
        """Return the latest serial data recorded for ``port_number`` in ``direction``."""
        res = self._stub.GetLatestSerialData(
            pb.GetLatestSerialDataRequest(port_number=port_number, direction=direction))
        return res.serialData

    def wait_for_next_serial_data(self, port_number, timeout_ms):
        """Wait (server side, up to ``timeout_ms``) for the next serial data on ``port_number``.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout_ms(timeout_ms)
        res = self._stub.WaitForNextSerialData(
            pb.WaitForNextSerialDataRequest(port_number=port_number, timeout=timeout))
        return res.serialData

    # ------------------------------------------------------------------
    # RPi-GP10
    # ------------------------------------------------------------------

    def write_logic_rpigp10(self, board_id: str, channel_number: int, state: bool):
        """Write logic level ``state`` to one RPi-GP10 channel."""
        self._stub.WriteLogicRPiGP10(pb.WriteLogicRPiGP10Request(
            board=self._board_request(board_id, channel_number), state=state))

    def write_logic_rpigp10_multi(self, board_id: str, states: list):
        """Write several RPi-GP10 channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicRPiGP10Multi(pb.WriteLogicRPiGP10MultiRequest(
            board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_rpigp10_bits(self, board_id: str, state: int):
        """Write RPi-GP10 outputs from the integer ``state`` (presumably one bit per channel)."""
        self._stub.WriteLogicRPiGP10Bits(pb.WriteLogicRPiGP10BitsRequest(board_id=board_id, state=state))

    def read_logic_rpigp10(self, board_id: str, channel_number: int):
        """Return the logic state of one RPi-GP10 channel."""
        res = self._stub.ReadLogicRPiGP10(self._board_request(board_id, channel_number))
        return res.state

    # ------------------------------------------------------------------
    # Analog / PWM boards
    # ------------------------------------------------------------------

    def write_analog_amao(self, board_id: str, channel_number: int, voltage_mv: int):
        """Set an AMAO analog output channel to ``voltage_mv`` (sent as ``voltage``)."""
        self._stub.WriteAnalogAMAO(pb.WriteAnalogAMAORequest(
            board=self._board_request(board_id, channel_number), voltage=voltage_mv))

    def read_analog_rpigp40(self, board_id: str, channel_number: int):
        """Return the ``voltage`` reading of one RPi-GP40 channel."""
        res = self._stub.ReadAnalogRPiGP40(self._board_request(board_id, channel_number))
        return res.voltage

    def write_pwm_ampio(self, board_id: str, channel_number: int, frequency_hz: int, duty_percent: float):
        """Configure PWM on an AMPIO channel with ``frequency_hz`` and ``duty_percent``."""
        self._stub.WritePwmAMPIO(pb.WritePwmAMPIORequest(
            board=self._board_request(board_id, channel_number),
            frequency=frequency_hz,
            duty=duty_percent,
        ))

    def read_pwm_ampio(self, board_id: str, channel_number: int):
        """Return ``(frequency, duty)`` measured on an AMPIO channel."""
        res = self._stub.ReadPwmAMPIO(self._board_request(board_id, channel_number))
        return res.frequency, res.duty

    def read_analog_cpiai1208li(self, board_id: str, channel_number: int):
        """Return the ``voltage`` reading of one CPI-AI-1208LI channel."""
        res = self._stub.ReadAnalogCPIAI1208LI(self._board_request(board_id, channel_number))
        return res.voltage

    # ------------------------------------------------------------------
    # CPI-RRY-16
    # ------------------------------------------------------------------

    def write_logic_cpirry16(self, board_id: str, channel_number: int, state: bool):
        """Write logic level ``state`` to one CPI-RRY-16 channel."""
        self._stub.WriteLogicCpiRry16(pb.WriteLogicCpiRry16Request(
            board=self._board_request(board_id, channel_number), state=state))

    def write_logic_cpirry16_multi(self, board_id: str, states: list):
        """Write several CPI-RRY-16 channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicCpiRry16Multi(pb.WriteLogicCpiRry16MultiRequest(
            board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_cpirry16_bits(self, board_id: str, state: int):
        """Write CPI-RRY-16 outputs from the integer ``state`` (presumably one bit per channel)."""
        self._stub.WriteLogicCpiRry16Bits(pb.WriteLogicCpiRry16BitsRequest(board_id=board_id, state=state))

    # ------------------------------------------------------------------
    # CPI-DIO-0808L
    # ------------------------------------------------------------------

    def write_logic_cpidio0808l(self, board_id: str, channel_number: int, state: bool):
        """Write logic level ``state`` to one CPI-DIO-0808L channel."""
        self._stub.WriteLogicCpiDio0808L(pb.WriteLogicCpiDio0808LRequest(
            board=self._board_request(board_id, channel_number), state=state))

    def write_logic_cpidio0808l_multi(self, board_id: str, states: list):
        """Write several CPI-DIO-0808L channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicCpiDio0808LMulti(pb.WriteLogicCpiDio0808LMultiRequest(
            board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_cpidio0808l_bits(self, board_id: str, state: int):
        """Write CPI-DIO-0808L outputs from the integer ``state`` (presumably one bit per channel)."""
        self._stub.WriteLogicCpiDio0808LBits(pb.WriteLogicCpiDio0808LBitsRequest(board_id=board_id, state=state))

    def read_logic_cpidio0808l(self, board_id: str, channel_number: int):
        """Return the logic state of one CPI-DIO-0808L channel."""
        res = self._stub.ReadLogicCpiDio0808L(self._board_request(board_id, channel_number))
        return res.state

    # ------------------------------------------------------------------
    # CAN
    # ------------------------------------------------------------------

    def send_can(self, user_defined_name, channel, can_id, can_data):
        """Send one CAN frame ``can_id``/``can_data`` on ``channel`` of device ``user_defined_name``."""
        self._stub.SendCan(pb.SendCanRequest(
            user_defined_name=user_defined_name, channel=channel, can_id=can_id, can_data=can_data))

    def send_can_message(self, user_defined_name, channel, can_message_name):
        """Send the CAN message named ``can_message_name`` on ``channel``."""
        self._stub.SendCanMessage(pb.SendCanMessageRequest(
            user_defined_name=user_defined_name, channel=channel, can_message_name=can_message_name))

    def get_latest_can(self, user_defined_name, channel, can_id, direction):
        """Return the latest CAN data recorded for ``can_id`` on ``channel`` in ``direction``."""
        res = self._stub.GetLatestCanMessage(pb.GetLatestCanMessageRequest(
            user_defined_name=user_defined_name, channel=channel, can_id=can_id, direction=direction))
        return res.can_data

    def wait_for_next_can_message(self, user_defined_name, channel, can_id, timeout_ms):
        """Wait (server side, up to ``timeout_ms``) for the next CAN message with ``can_id``.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout_ms(timeout_ms)
        res = self._stub.WaitForNextCanMessage(pb.WaitForNextCanMessageRequest(
            user_defined_name=user_defined_name, channel=channel, can_id=can_id, timeout=timeout))
        return res.can_data

    def start_can_periodic(self, user_defined_name, channel, can_id, can_data, period):
        """Start sending ``can_id``/``can_data`` periodically every ``period`` on ``channel``."""
        self._stub.StartCanPeriodic(pb.StartCanPeriodicRequest(
            user_defined_name=user_defined_name, channel=channel, can_id=can_id,
            can_data=can_data, period=period))

    def stop_can_periodic(self, user_defined_name, channel, can_id):
        """Stop periodic sending of ``can_id`` on ``channel``.

        ``can_id=None`` is sent as ``-1`` (presumably "all IDs" -- confirm on the server).

        Raises:
            ValueError: ``can_id`` is negative.
        """
        if can_id is None:
            can_id = -1
        elif can_id < 0:
            raise ValueError("Cannot set up the 'can_id' less than 0.")
        self._stub.StopCanPeriodic(pb.StopCanPeriodicRequest(
            user_defined_name=user_defined_name, channel=channel, can_id=can_id))

    def start_can_message_periodic(self, user_defined_name, can_message_name):
        """Start periodic sending of the CAN message named ``can_message_name``."""
        self._stub.StartCanMessagePeriodic(pb.StartCanMessagePeriodicRequest(
            user_defined_name=user_defined_name, can_message_name=can_message_name))

    def stop_can_message_periodic(self, user_defined_name, can_message_name):
        """Stop periodic sending of the CAN message named ``can_message_name``."""
        self._stub.StopCanMessagePeriodic(pb.StopCanMessagePeriodicRequest(
            user_defined_name=user_defined_name, can_message_name=can_message_name))

    # ------------------------------------------------------------------
    # UDP
    # ------------------------------------------------------------------

    def send_udp(self, endpoint_name, remote_ip, remote_port, bytes_data):
        """Send ``bytes_data`` from endpoint ``endpoint_name`` to ``remote_ip:remote_port``."""
        self._stub.SendUDP(pb.SendUDPRequest(
            endpoint_name=endpoint_name, remote_ip=remote_ip, remote_port=remote_port, udp_data=bytes_data))

    def get_latest_udp(self, endpoint_name, direction):
        """Return the latest UDP data recorded for ``endpoint_name`` in ``direction``."""
        res = self._stub.GetLatestUDPData(pb.GetLatestUDPDataRequest(endpoint_name=endpoint_name, direction=direction))
        return res.udp_data

    def wait_for_next_udp_data(self, endpoint_name, timeout_ms):
        """Wait (server side, up to ``timeout_ms``) for the next UDP data on ``endpoint_name``.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout_ms(timeout_ms)
        res = self._stub.WaitForNextUDPData(pb.WaitForNextUDPDataRequest(endpoint_name=endpoint_name, timeout=timeout))
        return res.udp_data

    # ------------------------------------------------------------------
    # MG400 robot
    # ------------------------------------------------------------------

    def start_mg400_project(self, mg400_name: str, project_name: str) -> MG400ProjectThread:
        """Run an MG400 project in a background thread and wait until the robot leaves ``ENABLE``.

        ``RunMG400Project`` is called from a ``MG400ProjectThread``; this method
        then polls ``get_mg400_state`` every 0.1 s while the state is ``ENABLE``.

        Returns:
            The running thread; ``join`` it to wait for the RPC to return and to
            re-raise any error it produced.

        Raises:
            Exception: The error raised by the RPC, if the thread died while the
                state was still ``ENABLE``.
            RuntimeError: The thread ended without an error while still ``ENABLE``.
            TimeoutError: The state stayed ``ENABLE`` for 10 seconds. The thread
                is left running in that case.
        """
        poll_interval_s = 0.1
        max_retries = 100  # 100 * 0.1 s = 10 seconds
        thread = MG400ProjectThread(
            target=lambda: self._stub.RunMG400Project(
                pb.RunMG400ProjectRequest(mg400_name=mg400_name, project_name=project_name))
        )
        thread.start()
        retries = 0
        while self.get_mg400_state(mg400_name) == MG400State.ENABLE:
            if not thread.is_alive():
                if thread.exception is not None:
                    raise thread.exception
                raise RuntimeError("Failed to start MG400 project.")
            if retries >= max_retries:
                raise TimeoutError(f"MG400 project did not start within {max_retries * poll_interval_s} seconds.")
            sleep(poll_interval_s)
            retries += 1
        return thread

    def stop_mg400_project(self, mg400_name: str):
        """Call ``StopMG400Project`` for robot ``mg400_name``."""
        self._stub.StopMG400Project(pb.StopMG400ProjectRequest(mg400_name=mg400_name))

    def get_mg400_state(self, mg400_name: str):
        """Return the current state of robot ``mg400_name`` as an ``MG400State``."""
        res = self._stub.GetMG400State(pb.GetMG400StateRequest(mg400_name=mg400_name))
        return MG400State(res.state)

    # ------------------------------------------------------------------
    # Connected VD
    # ------------------------------------------------------------------

    def load_vd_config(self, config_file_path: str):
        """Load the Connected VD configuration file and start the server.

        Args:
            config_file_path: Path of the configuration file.
        """
        self._stub.LoadCoVDConfig(pb.LoadCoVDConfigRequest(config_file_path=config_file_path))

    def update_vd_config(self, **kwargs):
        """Partially update the Connected VD configuration.

        Keys are sent unchanged. Every value is sent as a string: lists and
        dicts are JSON-encoded, and anything else goes through ``str()``.

        Args:
            **kwargs: Parameters to update, given as keyword arguments.
                Example::

                    update_vd_config(
                        input_media_path="C:/videos/test.mp4",
                        output_path="C:/results",
                    )
        """
        import json  # local import: only this method needs it

        config_map = {
            key: json.dumps(value) if isinstance(value, (list, dict)) else str(value)
            for key, value in kwargs.items()
        }
        self._stub.UpdateCoVDConfig(pb.UpdateCoVDConfigRequest(config=config_map))

    def start_vd_judge(self):
        """Start Connected VD judging. The configuration must already be loaded.

        Returns:
            A started ``VDThread`` that blocks on ``WaitCoVDJudgeFinish``; join it
            to wait for judging to finish.
        """
        self._stub.StartCoVDJudge(pb.Empty())
        thread = VDThread(target=lambda: self._stub.WaitCoVDJudgeFinish(pb.Empty()))
        thread.start()
        return thread

    def stop_vd_judge(self):
        """Call ``StopCoVDJudge``."""
        self._stub.StopCoVDJudge(pb.Empty())

    def get_vd_log_list(self):
        """Return the Connected VD log items as a list."""
        res = self._stub.GetCoVDLogList(pb.Empty())
        return list(res.log_items)

    # ------------------------------------------------------------------
    # Streams (shared plumbing)
    # ------------------------------------------------------------------

    def _open_stream(self, open_method, stream_type: str, max_retries: int = 100):
        """Open a server stream and consume messages until the ``STREAM_OPEN`` marker arrives.

        Messages received before the marker are discarded.

        Args:
            open_method: gRPC stream-opening stub method (e.g. ``self._stub.OpenSerialStream``).
            stream_type: Stream type name, used in error messages.
            max_retries: Maximum number of non-marker messages to receive (default: 100).

        Returns:
            The opened stream, positioned just after the marker message.

        Raises:
            TimeoutError: ``max_retries`` messages arrived without the marker. The
                stream is cancelled before raising.
            RuntimeError: The server ended the stream before sending the marker.
        """
        stream = open_method(pb.Empty())
        retries = 0
        for message in stream:
            try:
                decoded = message.data.decode("utf-8")
            except UnicodeDecodeError:
                # Binary payloads cannot be the marker; keep waiting.
                decoded = None
            if decoded == self.STREAM_OPEN_MESSAGE:
                return stream
            retries += 1
            if retries >= max_retries:
                # Cancel so the server-side stream is not left open.
                stream.cancel()
                raise TimeoutError(f"{stream_type} stream open timed out after {max_retries} messages.")
        # The iterator ended cleanly without the marker; handing back the
        # exhausted stream would make the caller silently receive nothing.
        raise RuntimeError(f"{stream_type} stream ended before the open message was received.")

    def _close_stream(self, stream, close_method):
        """Close a stream opened by ``_open_stream``.

        Args:
            stream: The stream to close.
            close_method: gRPC close stub method (e.g. ``self._stub.CloseSerialStream``).
        """
        try:
            # Cancel the client side first...
            stream.cancel()
        finally:
            # ...then always tell the server, even if cancelling failed.
            close_method(pb.Empty())

    def open_serial_stream(self):
        """Open the serial stream (see ``_open_stream``)."""
        return self._open_stream(self._stub.OpenSerialStream, "Serial")

    def close_serial_stream(self, stream):
        """Close a stream returned by ``open_serial_stream``."""
        self._close_stream(stream, self._stub.CloseSerialStream)

    def create_serial_stream(self):
        """Return a ``SerialStream`` context manager bound to this client."""
        return SerialStream(self)

    def open_can_stream(self):
        """Open the CAN stream (see ``_open_stream``)."""
        return self._open_stream(self._stub.OpenCANStream, "CAN")

    def close_can_stream(self, stream):
        """Close a stream returned by ``open_can_stream``."""
        self._close_stream(stream, self._stub.CloseCANStream)

    def create_can_stream(self):
        """Return a ``CANStream`` context manager bound to this client."""
        return CANStream(self)

    def open_udp_stream(self):
        """Open the UDP stream (see ``_open_stream``)."""
        return self._open_stream(self._stub.OpenUDPStream, "UDP")

    def close_udp_stream(self, stream):
        """Close a stream returned by ``open_udp_stream``."""
        self._close_stream(stream, self._stub.CloseUDPStream)

    def create_udp_stream(self):
        """Return a ``UDPStream`` context manager bound to this client."""
        return UDPStream(self)

    # ------------------------------------------------------------------
    # TCP client
    # ------------------------------------------------------------------

    def get_latest_tcp_client(self, client_name: str, direction: str):
        """Return the latest TCP data recorded for ``client_name`` in ``direction``."""
        res = self._stub.GetLatestTCPClientData(
            pb.GetLatestTCPClientDataRequest(client_name=client_name, direction=direction))
        return res.data

    def wait_for_next_tcp_client_data(self, client_name: str, timeout_ms: int | None):
        """Wait (server side, up to ``timeout_ms``) for the next TCP data on ``client_name``.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout_ms(timeout_ms)
        res = self._stub.WaitForNextTCPClientData(
            pb.WaitForNextTCPClientDataRequest(client_name=client_name, timeout=timeout))
        return res.data

    def open_tcp_client_stream(self):
        """Open the TCP client stream (see ``_open_stream``)."""
        return self._open_stream(self._stub.OpenTCPClientStream, "TCP Client")

    def close_tcp_client_stream(self, stream):
        """Close a stream returned by ``open_tcp_client_stream``."""
        self._close_stream(stream, self._stub.CloseTCPClientStream)

    def create_tcp_client_stream(self):
        """Return a ``TCPClientStream`` context manager bound to this client."""
        return TCPClientStream(self)

    def connect_tcp_client(self, client_name: str):
        """Call ``ConnectTCPClient`` for ``client_name``."""
        self._stub.ConnectTCPClient(pb.TCPClientRequest(client_name=client_name))

    def disconnect_tcp_client(self, client_name: str):
        """Call ``DisconnectTCPClient`` for ``client_name``."""
        self._stub.DisconnectTCPClient(pb.TCPClientRequest(client_name=client_name))

    def is_connected_tcp_client(self, client_name: str):
        """Return the server-reported ``is_connected`` flag for ``client_name``."""
        res = self._stub.IsConnectedTCPClient(pb.TCPClientRequest(client_name=client_name))
        return res.is_connected

    def send_tcp_client(self, client_name: str, data: bytes):
        """Send ``data`` through TCP client ``client_name``."""
        self._stub.SendTCPClient(pb.SendTCPClientRequest(client_name=client_name, data=data))
class MG400ProjectThread(threading.Thread):
    """Non-daemon thread that runs an MG400 project RPC and keeps its error.

    Any ``Exception`` raised by ``target`` is stored in ``exception`` instead of
    escaping the thread, and is re-raised to the caller of ``join``.
    """

    def __init__(self, target) -> None:
        super().__init__(target=target, daemon=False)
        # Exception raised by ``target``, if any; set by ``run``.
        self.exception: Exception | None = None

    def run(self) -> None:
        """Run ``target`` and capture any ``Exception`` it raises."""
        try:
            super().run()
        except Exception as e:
            self.exception = e

    def join(self, timeout_ms: int | None = None) -> None:
        """Wait for the thread to finish.

        Args:
            timeout_ms: Maximum wait in milliseconds; ``None`` waits indefinitely.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
            Exception: Whatever ``target`` raised, re-raised here.
            TimeoutError: The thread is still running after ``timeout_ms``.
        """
        if timeout_ms is not None and timeout_ms <= 0:
            raise ValueError("Cannot set up the 'timeout' 0 or less.")
        timeout_sec = None if timeout_ms is None else timeout_ms / 1000
        super().join(timeout=timeout_sec)
        # Identity check: an exception object must be re-raised even if it
        # happens to evaluate as falsy.
        if self.exception is not None:
            raise self.exception
        if self.is_alive():
            raise TimeoutError("MG400ProjectThread join timed out.")
@unique
class MG400State(Enum):
    """MG400 robot state; built from the server's integer ``state`` value."""

    NONE = 0  # not connected
    ENABLE = 1  # idle / waiting
    DISABLE = 2  # cannot operate
    MOVING = 3  # in motion
    ERROR = 4  # error
    OTHER = 5  # other (a state AUTOmeal does not support)
class VDThread(threading.Thread):
    """Non-daemon thread that waits for Connected VD judging and keeps its error.

    Any ``Exception`` raised by ``target`` is stored in ``exception`` instead of
    escaping the thread, and is re-raised to the caller of ``join``.
    """

    def __init__(self, target) -> None:
        super().__init__(target=target, daemon=False)
        # Exception raised by ``target``, if any; set by ``run``.
        self.exception: Exception | None = None

    def run(self) -> None:
        """Run ``target`` and capture any ``Exception`` it raises."""
        try:
            super().run()
        except Exception as e:
            self.exception = e

    def join(self, timeout_ms: int | None = None) -> None:
        """Wait for the thread to finish.

        Args:
            timeout_ms: Maximum wait in milliseconds; ``None`` waits indefinitely.

        Raises:
            ValueError: ``timeout_ms`` is 0 or negative.
            Exception: Whatever ``target`` raised, re-raised here.
            TimeoutError: The thread is still running after ``timeout_ms``.
        """
        if timeout_ms is not None and timeout_ms <= 0:
            raise ValueError("Cannot set up the 'timeout' 0 or less.")
        timeout_sec = None if timeout_ms is None else timeout_ms / 1000
        super().join(timeout=timeout_sec)
        # Identity check: an exception object must be re-raised even if it
        # happens to evaluate as falsy.
        if self.exception is not None:
            raise self.exception
        if self.is_alive():
            raise TimeoutError("VDThread join timed out.")
class PacketCommunicationStream:
    """Base context manager: opens a packet stream on entry and closes it on exit.

    Subclasses provide ``_open_stream`` and ``_close_stream``.
    """

    def __init__(self, test_runner: TestRunner):
        self._test_runner = test_runner
        # The object returned by ``_open_stream``; assigned in ``__enter__``.
        self.stream = None

    def __enter__(self):
        """Open the stream and hand it to the ``with`` target."""
        opened = self._open_stream()
        self.stream = opened
        return opened

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the stream. This runs even when the ``with`` body raised.

        A failure while closing is only reported as a printed warning, because
        this is cleanup. Returns ``False`` so any exception from the ``with``
        body propagates unchanged.
        """
        try:
            self._close_stream()
        except Exception as cleanup_error:
            print(f"Warning: Error during stream cleanup: {cleanup_error}", flush=True)
        return False

    def _open_stream(self):
        """Open and return the underlying stream; subclasses must override."""
        raise NotImplementedError()

    def _close_stream(self):
        """Close ``self.stream``; subclasses must override."""
        raise NotImplementedError()
class SerialStream(PacketCommunicationStream):
    """Context manager around ``TestRunner.open_serial_stream`` / ``close_serial_stream``."""

    def _open_stream(self):
        return self._test_runner.open_serial_stream()

    def _close_stream(self):
        self._test_runner.close_serial_stream(self.stream)
@dataclass
class SerialRecord:
    """One serial-communication record, typically built via ``from_stream_data``.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream message carries milliseconds).
        direction: Transfer direction (e.g. ``"Receive"``).
        port: Serial port name.
        data: Transferred payload bytes.
    """

    time: timedelta
    direction: str
    port: str
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.SerialRecord) -> SerialRecord:
        """Build a ``SerialRecord`` from a gRPC stream message; ``stream_data.time`` is in milliseconds."""
        return cls(
            time=timedelta(milliseconds=stream_data.time),
            direction=stream_data.direction,
            port=stream_data.port,
            data=stream_data.data,
        )
class CANStream(PacketCommunicationStream):
    """Context manager around ``TestRunner.open_can_stream`` / ``close_can_stream``."""

    def _open_stream(self):
        return self._test_runner.open_can_stream()

    def _close_stream(self):
        self._test_runner.close_can_stream(self.stream)
@dataclass
class CANRecord:
    """One CAN-communication record, typically built via ``from_stream_data``.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream message carries milliseconds).
        direction: Communication direction (e.g. ``"Receive"``).
        identification_name: Identification name of the hardware.
        channel: CAN channel number used.
        frame_type: Frame type (``"Data Frame"`` or ``"Error Frame"``).
        can_id: CAN message ID.
        dlc: Data length code.
        data: CAN message data payload.
    """

    time: timedelta
    direction: str
    identification_name: str
    channel: int
    frame_type: str
    can_id: int
    dlc: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.CANRecord) -> CANRecord:
        """Build a ``CANRecord`` from a gRPC stream message; ``stream_data.time`` is in milliseconds."""
        return cls(
            time=timedelta(milliseconds=stream_data.time),
            direction=stream_data.direction,
            identification_name=stream_data.identification_name,
            channel=stream_data.channel,
            frame_type=stream_data.frame_type,
            can_id=stream_data.can_id,
            dlc=stream_data.dlc,
            data=stream_data.data,
        )
class UDPStream(PacketCommunicationStream):
    """Context manager around ``TestRunner.open_udp_stream`` / ``close_udp_stream``."""

    def _open_stream(self):
        return self._test_runner.open_udp_stream()

    def _close_stream(self):
        self._test_runner.close_udp_stream(self.stream)
@dataclass
class UDPRecord:
    """One UDP-communication record, typically built via ``from_stream_data``.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream message carries milliseconds).
        direction: Transfer direction (e.g. ``"Receive"``).
        endpoint_name: Endpoint name.
        local_ip: Local IP address.
        local_port: Local port number.
        remote_ip: Remote IP address.
        remote_port: Remote port number.
        data: Transferred payload bytes.
    """

    time: timedelta
    direction: str
    endpoint_name: str
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.UDPRecord) -> UDPRecord:
        """Build a ``UDPRecord`` from a gRPC stream message; ``stream_data.time`` is in milliseconds."""
        return cls(
            time=timedelta(milliseconds=stream_data.time),
            direction=stream_data.direction,
            endpoint_name=stream_data.endpoint_name,
            local_ip=stream_data.local_ip,
            local_port=stream_data.local_port,
            remote_ip=stream_data.remote_ip,
            remote_port=stream_data.remote_port,
            data=stream_data.data,
        )
class TCPClientStream(PacketCommunicationStream):
    """Context manager around ``TestRunner.open_tcp_client_stream`` / ``close_tcp_client_stream``."""

    def _open_stream(self):
        return self._test_runner.open_tcp_client_stream()

    def _close_stream(self):
        self._test_runner.close_tcp_client_stream(self.stream)
@dataclass
class TCPClientRecord:
    """One TCP-client communication record, typically built via ``from_stream_data``.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream message carries milliseconds).
        direction: Transfer direction (e.g. ``"Receive"``).
        client_name: Client name.
        local_ip: Local IP address.
        local_port: Local port number.
        server_ip: Server IP address.
        server_port: Server port number.
        data: Transferred payload bytes.
    """

    time: timedelta
    direction: str
    client_name: str
    local_ip: str
    local_port: int
    server_ip: str
    server_port: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.TCPClientRecord) -> TCPClientRecord:
        """Build a ``TCPClientRecord`` from a gRPC stream message; ``stream_data.time`` is in milliseconds."""
        return cls(
            time=timedelta(milliseconds=stream_data.time),
            direction=stream_data.direction,
            client_name=stream_data.client_name,
            local_ip=stream_data.local_ip,
            local_port=stream_data.local_port,
            server_ip=stream_data.server_ip,
            server_port=stream_data.server_port,
            data=stream_data.data,
        )
@unique
class TestRunnerEnvKeys(Enum):
    """Names of the environment variables provided by TestRunner."""

    REPORT_SAVE_DIR = "AM_REPORT_SAVE_DIR"
@unique
class TestParameterNames(Enum):
    """Names of the test parameters provided by TestRunner."""

    REPORT_SAVE_DIRECTORY = "report_save_directory"