Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util/agents/usb_hid_relay: fix concurrent access #1526

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions labgrid/util/agents/usb_hid_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
- Turn digital output on and off
"""

import errno
from contextlib import contextmanager
from time import monotonic, sleep

import usb.core
import usb.util

Expand All @@ -26,18 +30,35 @@ def __init__(self, **args):
raise ValueError("Device not found")

if self._dev.idVendor == 0x16C0:
self.set_output = self.set_output_dcttech
self.get_output = self.get_output_dcttech
self._set_output = self._set_output_dcttech
self._get_output = self._get_output_dcttech
elif self._dev.idVendor == 0x5131:
self.set_output = self.set_output_lcus
self.get_output = self.get_output_lcus
self._set_output = self._set_output_lcus
self._get_output = self._get_output_lcus
else:
raise ValueError(f"Unknown vendor/protocol for VID {self._dev.idVendor:x}")

if self._dev.is_kernel_driver_active(0):
self._dev.detach_kernel_driver(0)

def set_output_dcttech(self, number, status):
@contextmanager
def _claimed(self):
timeout = monotonic() + 1.0
while True:
try:
usb.util.claim_interface(self._dev, 0)
break
except usb.core.USBError as e:
if monotonic() > timeout:
raise e
if e.errno == errno.EBUSY:
sleep(0.01)
else:
raise e
yield
usb.util.release_interface(self._dev, 0)

def _set_output_dcttech(self, number, status):
assert 1 <= number <= 8
req = [0xFF if status else 0xFD, number]
self._dev.ctrl_transfer(
Expand All @@ -48,7 +69,7 @@ def set_output_dcttech(self, number, status):
req, # payload
)

def get_output_dcttech(self, number):
def _get_output_dcttech(self, number):
assert 1 <= number <= 8
resp = self._dev.ctrl_transfer(
usb.util.CTRL_TYPE_CLASS | usb.util.CTRL_RECIPIENT_DEVICE | usb.util.ENDPOINT_IN,
Expand All @@ -59,7 +80,7 @@ def get_output_dcttech(self, number):
)
return bool(resp[7] & (1 << (number - 1)))

def set_output_lcus(self, number, status):
def _set_output_lcus(self, number, status):
assert 1 <= number <= 8
ep_in = self._dev[0][(0, 0)][0]
ep_out = self._dev[0][(0, 0)][1]
Expand All @@ -68,13 +89,18 @@ def set_output_lcus(self, number, status):
ep_out.write(req)
ep_in.read(64)

def get_output_lcus(self, number):
def _get_output_lcus(self, number):
assert 1 <= number <= 8
# we have no information on how to read the current value
return False

def __del__(self):
usb.util.release_interface(self._dev, 0)
def set_output(self, number, status):
with self._claimed():
self._set_output(number, status)

def get_output(self, number):
with self._claimed():
self._get_output(number)


_relays = {}
Expand Down