"""automeal.test_runner: Python client for the AUTOmeal TestRunner gRPC service."""

from __future__ import annotations
import os
import threading
import traceback
import grpc
from enum import Enum, unique
from time import sleep
from .test_runner_protoc import test_runner_pb2 as pb
from .test_runner_protoc import test_runner_pb2_grpc as pbgrpc
from dataclasses import dataclass
from datetime import timedelta
import sys

class TestRunner:
    """Client for the AUTOmeal TestRunner gRPC service.

    Each public method issues one RPC on ``TestRunnerServiceStub`` over an
    insecure channel to ``localhost:<port>``. Errors raised by the stub are not
    caught here; they propagate to the caller.

    The channel is released by :meth:`close`, on leaving a ``with`` block, or
    (best effort) when the object is garbage-collected.
    """

    # Payload (UTF-8) the server sends on a packet stream once the stream is ready.
    STREAM_OPEN_MESSAGE = "STREAM_OPEN"

    # ------------------------------------------------------------------ lifecycle

    def __init__(self, port: int | str):
        """Open a channel to the TestRunner server.

        Args:
            port: TCP port of the server on ``localhost``.
        """
        self._channel = grpc.insecure_channel(f'localhost:{port}')
        self._stub = pbgrpc.TestRunnerServiceStub(self._channel)

    def close(self) -> None:
        """Close the gRPC channel. Calling it more than once is a no-op."""
        # getattr: __init__ may have failed before _channel was assigned.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            # Clear first so repeated close() / __del__ calls do nothing.
            self._channel = None
            channel.close()

    def __enter__(self) -> TestRunner:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False

    def __del__(self):
        self.close()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_timeout(timeout_ms: int | None) -> int:
        """Convert a user timeout to the value sent in ``Wait*`` requests.

        ``None`` is sent as ``-1`` (presumably "no timeout" on the server side).

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
        """
        if timeout_ms is None:
            return -1
        if timeout_ms <= 0:
            raise ValueError("Cannot set up the 'timeout' 0 or less.")
        return timeout_ms

    @staticmethod
    def _encode_multi_states(states: list) -> list[int]:
        """Encode per-channel states for ``*Multi`` requests.

        ``None`` becomes ``-1`` (presumably "leave this channel unchanged");
        every other entry is passed through ``int()`` (``True``/``False`` -> 1/0).
        """
        return [-1 if state is None else int(state) for state in states]

    @staticmethod
    def _encode_vd_config(values: dict) -> dict[str, str]:
        """Stringify Connected VD config values for ``UpdateCoVDConfigRequest``.

        Lists and dicts are JSON-serialized; everything else goes through ``str()``.
        """
        import json  # only needed here; imported once rather than per value

        return {
            key: json.dumps(value) if isinstance(value, (list, dict)) else str(value)
            for key, value in values.items()
        }

    @staticmethod
    def _board_request(board_id: str, channel_number: int):
        """Build the ``BoardRequest`` that addresses one channel of one board."""
        return pb.BoardRequest(board_id=board_id, channel_number=channel_number)

    # ------------------------------------------------------------------ test session

    def start_test(self, folder_name: str):
        """Ask the server to start the test in ``folder_name``."""
        self._stub.StartTest(pb.StartTestRequest(folder_name=folder_name))

    def stop_test(self):
        """Ask the server to stop the current test."""
        self._stub.StopTest(pb.Empty())

    # ------------------------------------------------------------------ serial

    def send_serial(self, port_number, serial_data):
        """Send ``serial_data`` on serial port ``port_number``."""
        self._stub.SendSerial(pb.SendSerialRequest(port_number=port_number, serial_data=serial_data))

    def send_serial_command(self, port_number, serial_command):
        """Send ``serial_command`` on serial port ``port_number``."""
        self._stub.SendSerialCommand(pb.SendSerialCommandRequest(port_number=port_number, serial_command=serial_command))

    def get_latest_serial(self, port_number, direction):
        """Return the latest serial data seen on ``port_number`` in ``direction``."""
        res = self._stub.GetLatestSerialData(pb.GetLatestSerialDataRequest(port_number=port_number, direction=direction))
        return res.serialData

    def wait_for_next_serial_data(self, port_number, timeout_ms):
        """Wait for the next serial data on ``port_number`` and return it.

        Args:
            port_number: Serial port to watch.
            timeout_ms: Timeout in milliseconds, or ``None`` (sent as ``-1``).

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout(timeout_ms)
        res = self._stub.WaitForNextSerialData(pb.WaitForNextSerialDataRequest(port_number=port_number, timeout=timeout))
        return res.serialData

    # ------------------------------------------------------------------ RPi-GP10

    def write_logic_rpigp10(self, board_id: str, channel_number: int, state: bool):
        """Set one RPi-GP10 output channel to ``state``."""
        self._stub.WriteLogicRPiGP10(pb.WriteLogicRPiGP10Request(board=self._board_request(board_id, channel_number), state=state))

    def write_logic_rpigp10_multi(self, board_id: str, states: list):
        """Set several RPi-GP10 channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicRPiGP10Multi(pb.WriteLogicRPiGP10MultiRequest(board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_rpigp10_bits(self, board_id: str, state: int):
        """Set all RPi-GP10 channels from the bit pattern ``state``."""
        self._stub.WriteLogicRPiGP10Bits(pb.WriteLogicRPiGP10BitsRequest(board_id=board_id, state=state))

    def read_logic_rpigp10(self, board_id: str, channel_number: int):
        """Return the logic state of one RPi-GP10 channel."""
        res = self._stub.ReadLogicRPiGP10(self._board_request(board_id, channel_number))
        return res.state

    # ------------------------------------------------------------------ analog / PWM

    def write_analog_amao(self, board_id: str, channel_number: int, voltage_mv: int):
        """Output ``voltage_mv`` on one AMAO analog channel."""
        self._stub.WriteAnalogAMAO(pb.WriteAnalogAMAORequest(board=self._board_request(board_id, channel_number), voltage=voltage_mv))

    def read_analog_rpigp40(self, board_id: str, channel_number: int):
        """Return the voltage read on one RPi-GP40 analog channel."""
        res = self._stub.ReadAnalogRPiGP40(self._board_request(board_id, channel_number))
        return res.voltage

    def write_pwm_ampio(self, board_id: str, channel_number: int, frequency_hz: int, duty_percent: float):
        """Drive one AMPIO channel with a PWM of ``frequency_hz`` and ``duty_percent``."""
        self._stub.WritePwmAMPIO(pb.WritePwmAMPIORequest(board=self._board_request(board_id, channel_number), frequency=frequency_hz, duty=duty_percent))

    def read_pwm_ampio(self, board_id: str, channel_number: int):
        """Return ``(frequency, duty)`` measured on one AMPIO channel."""
        res = self._stub.ReadPwmAMPIO(self._board_request(board_id, channel_number))
        return res.frequency, res.duty

    def read_analog_cpiai1208li(self, board_id: str, channel_number: int):
        """Return the voltage read on one CPI-AI-1208LI analog channel."""
        res = self._stub.ReadAnalogCPIAI1208LI(self._board_request(board_id, channel_number))
        return res.voltage

    # ------------------------------------------------------------------ CPI-RRY-16

    def write_logic_cpirry16(self, board_id: str, channel_number: int, state: bool):
        """Set one CPI-RRY-16 channel to ``state``."""
        self._stub.WriteLogicCpiRry16(pb.WriteLogicCpiRry16Request(board=self._board_request(board_id, channel_number), state=state))

    def write_logic_cpirry16_multi(self, board_id: str, states: list):
        """Set several CPI-RRY-16 channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicCpiRry16Multi(pb.WriteLogicCpiRry16MultiRequest(board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_cpirry16_bits(self, board_id: str, state: int):
        """Set all CPI-RRY-16 channels from the bit pattern ``state``."""
        self._stub.WriteLogicCpiRry16Bits(pb.WriteLogicCpiRry16BitsRequest(board_id=board_id, state=state))

    # ------------------------------------------------------------------ CPI-DIO-0808L

    def write_logic_cpidio0808l(self, board_id: str, channel_number: int, state: bool):
        """Set one CPI-DIO-0808L output channel to ``state``."""
        self._stub.WriteLogicCpiDio0808L(pb.WriteLogicCpiDio0808LRequest(board=self._board_request(board_id, channel_number), state=state))

    def write_logic_cpidio0808l_multi(self, board_id: str, states: list):
        """Set several CPI-DIO-0808L channels at once; ``None`` entries are sent as ``-1``."""
        self._stub.WriteLogicCpiDio0808LMulti(pb.WriteLogicCpiDio0808LMultiRequest(board_id=board_id, states=self._encode_multi_states(states)))

    def write_logic_cpidio0808l_bits(self, board_id: str, state: int):
        """Set all CPI-DIO-0808L outputs from the bit pattern ``state``."""
        self._stub.WriteLogicCpiDio0808LBits(pb.WriteLogicCpiDio0808LBitsRequest(board_id=board_id, state=state))

    def read_logic_cpidio0808l(self, board_id: str, channel_number: int):
        """Return the logic state of one CPI-DIO-0808L channel."""
        res = self._stub.ReadLogicCpiDio0808L(self._board_request(board_id, channel_number))
        return res.state

    # ------------------------------------------------------------------ CAN

    def send_can(self, user_defined_name, channel, can_id, can_data):
        """Send one CAN frame with ``can_id`` / ``can_data`` on ``channel``."""
        self._stub.SendCan(pb.SendCanRequest(user_defined_name=user_defined_name, channel=channel, can_id=can_id, can_data=can_data))

    def send_can_message(self, user_defined_name, channel, can_message_name):
        """Send the CAN message registered as ``can_message_name`` on ``channel``."""
        self._stub.SendCanMessage(pb.SendCanMessageRequest(user_defined_name=user_defined_name, channel=channel, can_message_name=can_message_name))

    def get_latest_can(self, user_defined_name, channel, can_id, direction):
        """Return the data of the latest CAN frame with ``can_id`` in ``direction``."""
        res = self._stub.GetLatestCanMessage(pb.GetLatestCanMessageRequest(user_defined_name=user_defined_name, channel=channel, can_id=can_id, direction=direction))
        return res.can_data

    def wait_for_next_can_message(self, user_defined_name, channel, can_id, timeout_ms):
        """Wait for the next CAN frame with ``can_id`` and return its data.

        Args:
            timeout_ms: Timeout in milliseconds, or ``None`` (sent as ``-1``).

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout(timeout_ms)
        res = self._stub.WaitForNextCanMessage(pb.WaitForNextCanMessageRequest(user_defined_name=user_defined_name, channel=channel, can_id=can_id, timeout=timeout))
        return res.can_data

    def start_can_periodic(self, user_defined_name, channel, can_id, can_data, period):
        """Start sending ``can_data`` with ``can_id`` every ``period``."""
        self._stub.StartCanPeriodic(pb.StartCanPeriodicRequest(user_defined_name=user_defined_name, channel=channel, can_id=can_id, can_data=can_data, period=period))

    def stop_can_periodic(self, user_defined_name, channel, can_id):
        """Stop periodic sending for ``can_id``.

        Args:
            can_id: CAN ID to stop, or ``None`` (sent as ``-1``; presumably "all IDs").

        Raises:
            ValueError: if ``can_id`` is negative.
        """
        if can_id is None:
            can_id = -1
        elif can_id < 0:
            raise ValueError("Cannot set up the 'can_id' less than 0.")
        self._stub.StopCanPeriodic(pb.StopCanPeriodicRequest(user_defined_name=user_defined_name, channel=channel, can_id=can_id))

    def start_can_message_periodic(self, user_defined_name, can_message_name):
        """Start periodic sending of the registered CAN message ``can_message_name``."""
        self._stub.StartCanMessagePeriodic(pb.StartCanMessagePeriodicRequest(user_defined_name=user_defined_name, can_message_name=can_message_name))

    def stop_can_message_periodic(self, user_defined_name, can_message_name):
        """Stop periodic sending of the registered CAN message ``can_message_name``."""
        self._stub.StopCanMessagePeriodic(pb.StopCanMessagePeriodicRequest(user_defined_name=user_defined_name, can_message_name=can_message_name))

    # ------------------------------------------------------------------ UDP

    def send_udp(self, endpoint_name, remote_ip, remote_port, bytes_data):
        """Send ``bytes_data`` from ``endpoint_name`` to ``remote_ip:remote_port``."""
        self._stub.SendUDP(pb.SendUDPRequest(endpoint_name=endpoint_name, remote_ip=remote_ip, remote_port=remote_port, udp_data=bytes_data))

    def get_latest_udp(self, endpoint_name, direction):
        """Return the latest UDP payload seen on ``endpoint_name`` in ``direction``."""
        res = self._stub.GetLatestUDPData(pb.GetLatestUDPDataRequest(endpoint_name=endpoint_name, direction=direction))
        return res.udp_data

    def wait_for_next_udp_data(self, endpoint_name, timeout_ms):
        """Wait for the next UDP payload on ``endpoint_name`` and return it.

        Args:
            timeout_ms: Timeout in milliseconds, or ``None`` (sent as ``-1``).

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout(timeout_ms)
        res = self._stub.WaitForNextUDPData(pb.WaitForNextUDPDataRequest(endpoint_name=endpoint_name, timeout=timeout))
        return res.udp_data

    # ------------------------------------------------------------------ MG400

    def start_mg400_project(self, mg400_name: str, project_name: str) -> MG400ProjectThread:
        """Run an MG400 project in a background thread and wait until it starts.

        ``RunMG400Project`` is issued from a new :class:`MG400ProjectThread`
        (the RPC presumably blocks until the project ends). The robot state is
        then polled every 0.1 s until it is no longer ``MG400State.ENABLE``.

        Args:
            mg400_name: Name of the MG400 robot.
            project_name: Project to run.

        Returns:
            The running thread; its ``join`` re-raises any error from the RPC.

        Raises:
            Exception: the RPC's exception, if the thread died while the state was still ENABLE.
            RuntimeError: if the thread ended without error while the state was still ENABLE.
            TimeoutError: if the state is still ENABLE after about 10 seconds.
        """
        thread = MG400ProjectThread(
            target=lambda: self._stub.RunMG400Project(pb.RunMG400ProjectRequest(mg400_name=mg400_name, project_name=project_name))
        )
        thread.start()
        poll_interval_sec = 0.1
        max_retries = 100  # 100 * 0.1 s = 10 seconds
        retries = 0
        while self.get_mg400_state(mg400_name) == MG400State.ENABLE:
            if not thread.is_alive():
                raise thread.exception if thread.exception else RuntimeError("Failed to start MG400 project.")
            if retries >= max_retries:
                raise TimeoutError(f"MG400 project did not start within {max_retries * poll_interval_sec} seconds.")
            sleep(poll_interval_sec)
            retries += 1
        return thread

    def stop_mg400_project(self, mg400_name: str):
        """Ask the server to stop the project running on ``mg400_name``."""
        self._stub.StopMG400Project(pb.StopMG400ProjectRequest(mg400_name=mg400_name))

    def get_mg400_state(self, mg400_name: str) -> MG400State:
        """Return the current state of ``mg400_name`` as an :class:`MG400State`."""
        res = self._stub.GetMG400State(pb.GetMG400StateRequest(mg400_name=mg400_name))
        return MG400State(res.state)

    # ------------------------------------------------------------------ Connected VD

    def load_vd_config(self, config_file_path: str):
        """Load a Connected VD configuration file and start the server.

        Args:
            config_file_path: Path of the configuration file.
        """
        self._stub.LoadCoVDConfig(pb.LoadCoVDConfigRequest(config_file_path=config_file_path))

    def update_vd_config(self, **kwargs):
        """Partially update the Connected VD configuration.

        Every value is sent as a string: lists and dicts are JSON-serialized,
        anything else goes through ``str()``. Keys are sent unchanged.

        Args:
            **kwargs: Parameters to update, given as keyword arguments.

        Example:
            update_vd_config(
                input_media_path="C:/videos/test.mp4",
                output_path="C:/results"
            )
        """
        config_map = self._encode_vd_config(kwargs)
        self._stub.UpdateCoVDConfig(pb.UpdateCoVDConfigRequest(config=config_map))

    def start_vd_judge(self) -> VDThread:
        """Start the Connected VD judgement (the configuration must already be loaded).

        Returns:
            A started thread that waits for the judgement to finish.
        """
        self._stub.StartCoVDJudge(pb.Empty())
        thread = VDThread(
            target=lambda: self._stub.WaitCoVDJudgeFinish(pb.Empty())
        )
        thread.start()
        return thread

    def stop_vd_judge(self):
        """Ask the server to stop the Connected VD judgement."""
        self._stub.StopCoVDJudge(pb.Empty())

    def get_vd_log_list(self) -> list:
        """Return the Connected VD log items as a list."""
        res = self._stub.GetCoVDLogList(pb.Empty())
        return list(res.log_items)

    # ------------------------------------------------------------------ packet streams

    def _open_stream(self, open_method, stream_type: str, max_retries: int = 100):
        """Open a server-streaming RPC and wait for its ready marker.

        Messages are consumed until one whose ``data`` decodes (UTF-8) to
        :attr:`STREAM_OPEN_MESSAGE`. The same stream object is then returned,
        so iterating it continues with the messages after the marker.
        On any failure the client side of the RPC is cancelled so it is not leaked.

        Args:
            open_method: gRPC method that opens the stream (e.g. ``self._stub.OpenSerialStream``).
            stream_type: Stream type name, used in error messages.
            max_retries: Maximum number of messages to read while waiting (default: 100).

        Returns:
            The opened stream.

        Raises:
            TimeoutError: if the marker is not among the first ``max_retries`` messages.
            RuntimeError: if the server ends the stream before sending the marker.
        """
        stream = open_method(pb.Empty())
        received = 0
        try:
            for message in stream:
                try:
                    decoded = message.data.decode("utf-8")
                except UnicodeDecodeError:
                    # Not valid UTF-8, so it cannot be the open marker; keep reading.
                    decoded = None

                if decoded == self.STREAM_OPEN_MESSAGE:
                    return stream

                received += 1
                if received >= max_retries:
                    raise TimeoutError(f"{stream_type} stream open timed out after {max_retries} messages.")
        except BaseException:
            # Release the RPC before propagating; otherwise it stays open.
            stream.cancel()
            raise
        # The iterator was exhausted without the marker: the server ended the stream.
        stream.cancel()
        raise RuntimeError(f"{stream_type} stream ended before it was opened.")

    def _close_stream(self, stream, close_method):
        """Close a packet stream on both sides.

        The client-side stream is cancelled first; the server is then told to
        close. The server call is made even if cancelling fails.

        Args:
            stream: Stream to close (``None`` skips the client-side cancel).
            close_method: gRPC method that closes the stream (e.g. ``self._stub.CloseSerialStream``).
        """
        try:
            if stream is not None:
                stream.cancel()
        finally:
            close_method(pb.Empty())

    def open_serial_stream(self):
        """Open the serial packet stream (see :meth:`_open_stream`)."""
        return self._open_stream(self._stub.OpenSerialStream, "Serial")

    def close_serial_stream(self, stream):
        """Close a stream returned by :meth:`open_serial_stream`."""
        self._close_stream(stream, self._stub.CloseSerialStream)

    def create_serial_stream(self) -> SerialStream:
        """Return a context manager that opens/closes the serial stream."""
        return SerialStream(self)

    def open_can_stream(self):
        """Open the CAN packet stream (see :meth:`_open_stream`)."""
        return self._open_stream(self._stub.OpenCANStream, "CAN")

    def close_can_stream(self, stream):
        """Close a stream returned by :meth:`open_can_stream`."""
        self._close_stream(stream, self._stub.CloseCANStream)

    def create_can_stream(self) -> CANStream:
        """Return a context manager that opens/closes the CAN stream."""
        return CANStream(self)

    def open_udp_stream(self):
        """Open the UDP packet stream (see :meth:`_open_stream`)."""
        return self._open_stream(self._stub.OpenUDPStream, "UDP")

    def close_udp_stream(self, stream):
        """Close a stream returned by :meth:`open_udp_stream`."""
        self._close_stream(stream, self._stub.CloseUDPStream)

    def create_udp_stream(self) -> UDPStream:
        """Return a context manager that opens/closes the UDP stream."""
        return UDPStream(self)

    # ------------------------------------------------------------------ TCP client

    def get_latest_tcp_client(self, client_name: str, direction: str):
        """Return the latest data seen by TCP client ``client_name`` in ``direction``."""
        res = self._stub.GetLatestTCPClientData(pb.GetLatestTCPClientDataRequest(client_name=client_name, direction=direction))
        return res.data

    def wait_for_next_tcp_client_data(self, client_name: str, timeout_ms: int | None):
        """Wait for the next data on TCP client ``client_name`` and return it.

        Args:
            timeout_ms: Timeout in milliseconds, or ``None`` (sent as ``-1``).

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
        """
        timeout = self._normalize_timeout(timeout_ms)
        res = self._stub.WaitForNextTCPClientData(pb.WaitForNextTCPClientDataRequest(client_name=client_name, timeout=timeout))
        return res.data

    def open_tcp_client_stream(self):
        """Open the TCP client packet stream (see :meth:`_open_stream`)."""
        return self._open_stream(self._stub.OpenTCPClientStream, "TCP Client")

    def close_tcp_client_stream(self, stream):
        """Close a stream returned by :meth:`open_tcp_client_stream`."""
        self._close_stream(stream, self._stub.CloseTCPClientStream)

    def create_tcp_client_stream(self) -> TCPClientStream:
        """Return a context manager that opens/closes the TCP client stream."""
        return TCPClientStream(self)

    def _get_env_value(self, key: str) -> str | None:
        """Return the value of environment variable ``key`` (internal use).

        Args:
            key: Environment variable name.

        Returns:
            The variable's value, or ``None`` if it is not set.

        Notes:
            This is an internal helper and the lookup mechanism may change in
            the future. Do not call it directly from user code.
        """
        return os.getenv(key)

    def connect_tcp_client(self, client_name: str):
        """Connect TCP client ``client_name``."""
        self._stub.ConnectTCPClient(pb.TCPClientRequest(client_name=client_name))

    def disconnect_tcp_client(self, client_name: str):
        """Disconnect TCP client ``client_name``."""
        self._stub.DisconnectTCPClient(pb.TCPClientRequest(client_name=client_name))

    def is_connected_tcp_client(self, client_name: str):
        """Return whether TCP client ``client_name`` is connected."""
        res = self._stub.IsConnectedTCPClient(pb.TCPClientRequest(client_name=client_name))
        return res.is_connected

    def send_tcp_client(self, client_name: str, data: bytes):
        """Send ``data`` through TCP client ``client_name``."""
        self._stub.SendTCPClient(pb.SendTCPClientRequest(client_name=client_name, data=data))


class MG400ProjectThread(threading.Thread):
    """Non-daemon thread that runs an MG400 project call and keeps its failure.

    An exception raised by ``target`` is stored in :attr:`exception` instead of
    escaping the thread; :meth:`join` re-raises it in the joining thread.
    """

    def __init__(self, target) -> None:
        super().__init__(target=target, daemon=False)
        # Exception raised by ``target``; None if it has not raised.
        self.exception: Exception | None = None

    def run(self) -> None:
        """Run ``target``, recording any Exception it raises."""
        try:
            super().run()
        except Exception as err:
            self.exception = err

    def join(self, timeout_ms: int | None = None) -> None:
        """Wait for the thread to finish.

        Args:
            timeout_ms: Maximum wait in milliseconds; ``None`` waits indefinitely.

        Raises:
            ValueError: if ``timeout_ms`` is 0 or negative.
            Exception: the exception captured from ``target``, if any.
            TimeoutError: if the thread is still alive after the wait.
        """
        wait_sec = None
        if timeout_ms is not None:
            if timeout_ms <= 0:
                raise ValueError("Cannot set up the 'timeout' 0 or less.")
            wait_sec = timeout_ms / 1000
        super().join(timeout=wait_sec)
        captured = self.exception
        if captured:
            raise captured
        if self.is_alive():
            raise TimeoutError("MG400ProjectThread join timed out.")
class MG400State(Enum):
    """MG400 robot state values returned by the ``GetMG400State`` RPC."""

    NONE = 0      # not connected
    ENABLE = 1    # idle / waiting
    DISABLE = 2   # not operable
    MOVING = 3    # in motion
    ERROR = 4     # error
    OTHER = 5     # any other state (not supported by AUTOmeal)
class VDThread(threading.Thread):
    """Non-daemon thread used to wait for the Connected VD judgement to finish.

    If ``target`` raises, the exception is kept in :attr:`exception` and
    re-raised from :meth:`join` rather than escaping the thread.
    """

    def __init__(self, target) -> None:
        super().__init__(target=target, daemon=False)
        # Captured exception from ``target`` (None while none has occurred).
        self.exception: Exception | None = None

    def run(self) -> None:
        """Execute ``target`` and capture, rather than propagate, any Exception."""
        try:
            super().run()
        except Exception as exc:
            self.exception = exc

    def join(self, timeout_ms: int | None = None) -> None:
        """Block until the thread ends or the timeout elapses.

        Args:
            timeout_ms: Timeout in milliseconds, or ``None`` for no limit.

        Raises:
            ValueError: for a timeout of 0 ms or less.
            Exception: whatever ``target`` raised, if it raised.
            TimeoutError: when the thread is still running after the timeout.
        """
        timeout_sec: float | None
        if timeout_ms is None:
            timeout_sec = None
        elif timeout_ms <= 0:
            raise ValueError("Cannot set up the 'timeout' 0 or less.")
        else:
            timeout_sec = timeout_ms / 1000

        super().join(timeout=timeout_sec)

        if self.exception:
            raise self.exception
        if not self.is_alive():
            return
        raise TimeoutError("VDThread join timed out.")
class PacketCommunicationStream:
    """Base class for context managers that open a packet stream on entry
    and close it on exit.

    Subclasses provide :meth:`_open_stream` and :meth:`_close_stream`.
    """

    def __init__(self, test_runner: TestRunner):
        self._test_runner = test_runner
        # Stream returned by _open_stream(); None until the with-block is entered.
        self.stream = None

    def __enter__(self):
        """Open the stream and give it to the ``with`` block."""
        opened = self._open_stream()
        self.stream = opened
        return opened

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the stream; runs even when the ``with`` block raised.

        A failure while closing is printed as a warning instead of raised.
        """
        try:
            self._close_stream()
        except Exception as err:
            # Cleanup is best effort: report the problem, never raise from here.
            print(f"Warning: Error during stream cleanup: {err}", flush=True)
        # False: do not suppress an exception raised inside the with-block.
        return False

    def _open_stream(self):
        """Open and return the underlying stream; subclasses must override."""
        raise NotImplementedError()

    def _close_stream(self):
        """Close ``self.stream``; subclasses must override."""
        raise NotImplementedError()
class SerialStream(PacketCommunicationStream):
    """``with``-block wrapper around the test runner's serial packet stream."""

    def _open_stream(self):
        runner = self._test_runner
        return runner.open_serial_stream()

    def _close_stream(self):
        runner = self._test_runner
        runner.close_serial_stream(self.stream)
@dataclass
class SerialRecord:
    """One record from the serial packet stream.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream carries milliseconds;
            the reference point is not defined here - TODO confirm).
        direction: Transfer direction (e.g. "Receive").
        port: Serial port name.
        data: Payload bytes.
    """

    time: timedelta
    direction: str
    port: str
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.SerialRecord) -> SerialRecord:
        """Build a SerialRecord from one gRPC stream message."""
        elapsed = timedelta(milliseconds=stream_data.time)
        return cls(elapsed, stream_data.direction, stream_data.port, stream_data.data)
class CANStream(PacketCommunicationStream):
    """``with``-block wrapper around the test runner's CAN packet stream."""

    def _open_stream(self):
        opener = self._test_runner.open_can_stream
        return opener()

    def _close_stream(self):
        closer = self._test_runner.close_can_stream
        closer(self.stream)
@dataclass
class CANRecord:
    """One record from the CAN packet stream.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream carries milliseconds;
            the reference point is not defined here - TODO confirm).
        direction: Transfer direction (e.g. "Receive").
        identification_name: Identification name of the hardware.
        channel: CAN channel number the frame was seen on.
        frame_type: Frame type ("Data Frame" or "Error Frame").
        can_id: CAN message ID.
        dlc: Data length code.
        data: CAN data payload.
    """

    time: timedelta
    direction: str
    identification_name: str
    channel: int
    frame_type: str
    can_id: int
    dlc: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.CANRecord) -> CANRecord:
        """Build a CANRecord from one gRPC stream message."""
        # All fields except ``time`` are copied verbatim from the message.
        copied = {
            name: getattr(stream_data, name)
            for name in ("direction", "identification_name", "channel", "frame_type", "can_id", "dlc", "data")
        }
        return cls(time=timedelta(milliseconds=stream_data.time), **copied)
class UDPStream(PacketCommunicationStream):
    """``with``-block wrapper around the test runner's UDP packet stream."""

    def _open_stream(self):
        runner = self._test_runner
        return runner.open_udp_stream()

    def _close_stream(self):
        runner = self._test_runner
        runner.close_udp_stream(self.stream)
@dataclass
class UDPRecord:
    """One record from the UDP packet stream.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream carries milliseconds;
            the reference point is not defined here - TODO confirm).
        direction: Transfer direction (e.g. "Receive").
        endpoint_name: Endpoint name.
        local_ip: Local IP address.
        local_port: Local port number.
        remote_ip: Remote IP address.
        remote_port: Remote port number.
        data: Payload bytes.
    """

    time: timedelta
    direction: str
    endpoint_name: str
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.UDPRecord) -> UDPRecord:
        """Build a UDPRecord from one gRPC stream message."""
        passthrough = ("direction", "endpoint_name", "local_ip", "local_port", "remote_ip", "remote_port", "data")
        values = {name: getattr(stream_data, name) for name in passthrough}
        values["time"] = timedelta(milliseconds=stream_data.time)
        return cls(**values)
class TCPClientStream(PacketCommunicationStream):
    """``with``-block wrapper around the test runner's TCP client packet stream."""

    def _open_stream(self):
        opener = self._test_runner.open_tcp_client_stream
        return opener()

    def _close_stream(self):
        closer = self._test_runner.close_tcp_client_stream
        closer(self.stream)
@dataclass
class TCPClientRecord:
    """One record from the TCP client packet stream.

    Attributes:
        time: Timestamp as a ``timedelta`` (the stream carries milliseconds;
            the reference point is not defined here - TODO confirm).
        direction: Transfer direction (e.g. "Receive").
        client_name: Client name.
        local_ip: Local IP address.
        local_port: Local port number.
        server_ip: Server IP address.
        server_port: Server port number.
        data: Payload bytes.
    """

    time: timedelta
    direction: str
    client_name: str
    local_ip: str
    local_port: int
    server_ip: str
    server_port: int
    data: bytes

    @classmethod
    def from_stream_data(cls, stream_data: pb.TCPClientRecord) -> TCPClientRecord:
        """Build a TCPClientRecord from one gRPC stream message."""
        src = stream_data
        stamp = timedelta(milliseconds=src.time)
        return cls(
            stamp,
            src.direction,
            src.client_name,
            src.local_ip,
            src.local_port,
            src.server_ip,
            src.server_port,
            src.data,
        )
@unique
class TestRunnerEnvKeys(Enum):
    """Names of the environment variables provided by the TestRunner."""

    # Presumably holds the directory where reports are saved (per the name).
    REPORT_SAVE_DIR = "AM_REPORT_SAVE_DIR"
@unique
class TestParameterNames(Enum):
    """Names of the test parameters provided by the TestRunner."""

    # Presumably the parameter naming the report save directory (per the name).
    REPORT_SAVE_DIRECTORY = "report_save_directory"