diff --git a/adb/adb.cpp b/adb/adb.cpp index 51d5876f5fe8..39e71e5f3b2f 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -253,6 +253,19 @@ void send_connect(atransport* t) { send_packet(cp, t); } +#if ADB_HOST + +void SendConnectOnHost(atransport* t) { + // Send an empty message before A_CNXN message. This is because the data toggle of the ep_out on + // host and ep_in on device may not be the same. + apacket* p = get_apacket(); + CHECK(p); + send_packet(p, t); + send_connect(t); +} + +#endif + // qual_overwrite is used to overwrite a qualifier string. dst is a // pointer to a char pointer. It is assumed that if *dst is non-NULL, it // was malloc'ed and needs to freed. *dst will be set to a dup of src. @@ -299,29 +312,29 @@ void parse_banner(const std::string& banner, atransport* t) { const std::string& type = pieces[0]; if (type == "bootloader") { D("setting connection_state to kCsBootloader"); - t->connection_state = kCsBootloader; + t->SetConnectionState(kCsBootloader); update_transports(); } else if (type == "device") { D("setting connection_state to kCsDevice"); - t->connection_state = kCsDevice; + t->SetConnectionState(kCsDevice); update_transports(); } else if (type == "recovery") { D("setting connection_state to kCsRecovery"); - t->connection_state = kCsRecovery; + t->SetConnectionState(kCsRecovery); update_transports(); } else if (type == "sideload") { D("setting connection_state to kCsSideload"); - t->connection_state = kCsSideload; + t->SetConnectionState(kCsSideload); update_transports(); } else { D("setting connection_state to kCsHost"); - t->connection_state = kCsHost; + t->SetConnectionState(kCsHost); } } static void handle_new_connection(atransport* t, apacket* p) { - if (t->connection_state != kCsOffline) { - t->connection_state = kCsOffline; + if (t->GetConnectionState() != kCsOffline) { + t->SetConnectionState(kCsOffline); handle_offline(t); } @@ -355,10 +368,10 @@ void handle_packet(apacket *p, atransport *t) if (p->msg.arg0){ send_packet(p, t); #if ADB_HOST - send_connect(t); + SendConnectOnHost(t); #endif } else { - t->connection_state = kCsOffline; + t->SetConnectionState(kCsOffline); handle_offline(t); send_packet(p, t); } @@ -372,7 +385,9 @@ void handle_packet(apacket *p, atransport *t) switch (p->msg.arg0) { #if ADB_HOST case ADB_AUTH_TOKEN: - t->connection_state = kCsUnauthorized; + if (t->GetConnectionState() == kCsOffline) { + t->SetConnectionState(kCsUnauthorized); + } send_auth_response(p->data, p->msg.data_length, t); break; #else @@ -391,7 +406,7 @@ void handle_packet(apacket *p, atransport *t) break; #endif default: - t->connection_state = kCsOffline; + t->SetConnectionState(kCsOffline); handle_offline(t); break; } @@ -1032,7 +1047,6 @@ static int SendOkay(int fd, const std::string& s) { SendProtocolString(fd, s); return 0; } -#endif int handle_host_request(const char* service, TransportType type, const char* serial, int reply_fd, asocket* s) { @@ -1051,7 +1065,6 @@ int handle_host_request(const char* service, TransportType type, android::base::quick_exit(0); } -#if ADB_HOST // "transport:" is used for switching transport with a specified serial number // "transport-usb:" is used for switching transport to the only USB transport // "transport-local:" is used for switching transport to the only local transport @@ -1096,16 +1109,10 @@ int handle_host_request(const char* service, TransportType type, if (!strcmp(service, "reconnect-offline")) { std::string response; close_usb_devices([&response](const atransport* transport) { - switch (transport->connection_state) { + switch (transport->GetConnectionState()) { case kCsOffline: case kCsUnauthorized: - response += "reconnecting "; - if (transport->serial) { - response += transport->serial; - } else { - response += ""; - } - response += "\n"; + response += "reconnecting " + transport->serial_name() + "\n"; return true; default: return false; @@ -1129,7 +1136,6 @@ int handle_host_request(const char* service, TransportType type, return 0; } -#if ADB_HOST if (!strcmp(service, "host-features")) { FeatureSet features = supported_features(); // Abuse features to report libusb status. @@ -1139,7 +1145,6 @@ int handle_host_request(const char* service, TransportType type, SendOkay(reply_fd, FeatureSetToString(features)); return 0; } -#endif // remove TCP transport if (!strncmp(service, "disconnect:", 11)) { @@ -1209,15 +1214,19 @@ int handle_host_request(const char* service, TransportType type, } if (!strcmp(service, "reconnect")) { - if (s->transport != nullptr) { - kick_transport(s->transport); + std::string response; + atransport* t = acquire_one_transport(type, serial, nullptr, &response, true); + if (t != nullptr) { + kick_transport(t); + response = + "reconnecting " + t->serial_name() + " [" + t->connection_state_name() + "]\n"; } - return SendOkay(reply_fd, "done"); + return SendOkay(reply_fd, response); } -#endif // ADB_HOST int ret = handle_forward_request(service, type, serial, reply_fd); if (ret >= 0) return ret - 1; return -1; } +#endif // ADB_HOST diff --git a/adb/adb.h b/adb/adb.h index aea5fb86b2a8..e3675d8442d7 100644 --- a/adb/adb.h +++ b/adb/adb.h @@ -139,7 +139,7 @@ int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply int get_available_local_transport_index(); #endif int init_socket_transport(atransport *t, int s, int port, int local); -void init_usb_transport(atransport *t, usb_handle *usb, ConnectionState state); +void init_usb_transport(atransport* t, usb_handle* usb); std::string getEmulatorSerialString(int console_port); #if ADB_HOST @@ -222,6 +222,9 @@ void handle_online(atransport *t); void handle_offline(atransport *t); void send_connect(atransport *t); +#if ADB_HOST +void SendConnectOnHost(atransport* t); +#endif void parse_banner(const std::string&, atransport* t); diff --git a/adb/adb_client.cpp b/adb/adb_client.cpp index b2b5c0ee00ee..b6568875cabb 100644 --- a/adb/adb_client.cpp +++ b/adb/adb_client.cpp @@ -136,8 +136,7 @@ int _adb_connect(const std::string& service, std::string* error) { return -2; } - if ((memcmp(&service[0],"host",4) != 0 || service == "host:reconnect") && - switch_socket_transport(fd, error)) { + if (memcmp(&service[0], "host", 4) != 0 && switch_socket_transport(fd, error)) { return -1; } @@ -147,11 +146,9 @@ int _adb_connect(const std::string& service, std::string* error) { return -1; } - if (service != "reconnect") { - if (!adb_status(fd, error)) { - adb_close(fd); - return -1; - } + if (!adb_status(fd, error)) { + adb_close(fd); + return -1; } D("_adb_connect: return fd %d", fd); diff --git a/adb/adb_trace.cpp b/adb/adb_trace.cpp index c369d6077e04..eac923d4d8f2 100644 --- a/adb/adb_trace.cpp +++ b/adb/adb_trace.cpp @@ -155,7 +155,7 @@ void adb_trace_init(char** argv) { } #endif -#if !defined(_WIN32) +#if ADB_HOST && !defined(_WIN32) // adb historically ignored $ANDROID_LOG_TAGS but passed it through to logcat. // If set, move it out of the way so that libbase logging doesn't try to parse it. std::string log_tags; @@ -168,7 +168,7 @@ void adb_trace_init(char** argv) { android::base::InitLogging(argv, &AdbLogger); -#if !defined(_WIN32) +#if ADB_HOST && !defined(_WIN32) // Put $ANDROID_LOG_TAGS back so we can pass it to logcat. if (!log_tags.empty()) setenv("ANDROID_LOG_TAGS", log_tags.c_str(), 1); #endif diff --git a/adb/adb_trace.h b/adb/adb_trace.h index aaffa296794d..fc6560cfbade 100644 --- a/adb/adb_trace.h +++ b/adb/adb_trace.h @@ -58,6 +58,9 @@ extern int adb_trace_mask; void adb_trace_init(char**); void adb_trace_enable(AdbTrace trace_tag); +// Include before stdatomic.h (introduced in cutils/trace.h) to avoid compile error. +#include + #define ATRACE_TAG ATRACE_TAG_ADB #include #include diff --git a/adb/client/usb_libusb.cpp b/adb/client/usb_libusb.cpp index c48a2517b510..fec4742b2236 100644 --- a/adb/client/usb_libusb.cpp +++ b/adb/client/usb_libusb.cpp @@ -62,12 +62,11 @@ struct DeviceHandleDeleter { using unique_device_handle = std::unique_ptr; struct transfer_info { - transfer_info(const char* name, uint16_t zero_mask) : - name(name), - transfer(libusb_alloc_transfer(0)), - zero_mask(zero_mask) - { - } + transfer_info(const char* name, uint16_t zero_mask, bool is_bulk_out) + : name(name), + transfer(libusb_alloc_transfer(0)), + is_bulk_out(is_bulk_out), + zero_mask(zero_mask) {} ~transfer_info() { libusb_free_transfer(transfer); @@ -75,6 +74,7 @@ struct transfer_info { const char* name; libusb_transfer* transfer; + bool is_bulk_out; bool transfer_complete; std::condition_variable cv; std::mutex mutex; @@ -96,12 +96,11 @@ struct usb_handle : public ::usb_handle { serial(serial), closing(false), device_handle(device_handle.release()), - read("read", zero_mask), - write("write", zero_mask), + read("read", zero_mask, false), + write("write", zero_mask, true), interface(interface), bulk_in(bulk_in), - bulk_out(bulk_out) { - } + bulk_out(bulk_out) {} ~usb_handle() { Close(); @@ -365,11 +364,6 @@ void usb_init() { device_poll_thread = new std::thread(poll_for_devices); android::base::at_quick_exit([]() { terminate_device_poll_thread = true; - std::unique_lock lock(usb_handles_mutex); - for (auto& it : usb_handles) { - it.second->Close(); - } - lock.unlock(); device_poll_thread->join(); }); } @@ -397,7 +391,8 @@ static int perform_usb_transfer(usb_handle* h, transfer_info* info, return; } - if (transfer->actual_length != transfer->length) { + // usb_read() can return when receiving some data. + if (info->is_bulk_out && transfer->actual_length != transfer->length) { LOG(DEBUG) << info->name << " transfer incomplete, resubmitting"; transfer->length -= transfer->actual_length; transfer->buffer += transfer->actual_length; @@ -491,8 +486,12 @@ int usb_read(usb_handle* h, void* d, int len) { info->transfer->num_iso_packets = 0; int rc = perform_usb_transfer(h, info, std::move(lock)); - LOG(DEBUG) << "usb_read(" << len << ") = " << rc; - return rc; + LOG(DEBUG) << "usb_read(" << len << ") = " << rc << ", actual_length " + << info->transfer->actual_length; + if (rc < 0) { + return rc; + } + return info->transfer->actual_length; } int usb_close(usb_handle* h) { diff --git a/adb/client/usb_linux.cpp b/adb/client/usb_linux.cpp index 3a45dbd711d9..6efed274b26d 100644 --- a/adb/client/usb_linux.cpp +++ b/adb/client/usb_linux.cpp @@ -401,7 +401,6 @@ static int usb_bulk_read(usb_handle* h, void* data, int len) { } } - int usb_write(usb_handle *h, const void *_data, int len) { D("++ usb_write ++"); @@ -429,19 +428,16 @@ int usb_read(usb_handle *h, void *_data, int len) int n; D("++ usb_read ++"); - while(len > 0) { + int orig_len = len; + while (len == orig_len) { int xfer = len; D("[ usb read %d fd = %d], path=%s", xfer, h->fd, h->path.c_str()); n = usb_bulk_read(h, data, xfer); D("[ usb read %d ] = %d, path=%s", xfer, n, h->path.c_str()); - if(n != xfer) { + if (n <= 0) { if((errno == ETIMEDOUT) && (h->fd != -1)) { D("[ timeout ]"); - if(n > 0){ - data += n; - len -= n; - } continue; } D("ERROR: n = %d, errno = %d (%s)", @@ -449,12 +445,12 @@ int usb_read(usb_handle *h, void *_data, int len) return -1; } - len -= xfer; - data += xfer; + len -= n; + data += n; } D("-- usb_read --"); - return 0; + return orig_len - len; } void usb_kick(usb_handle* h) { diff --git a/adb/client/usb_osx.cpp b/adb/client/usb_osx.cpp index 8713b2c4ba19..fcd0bc044b55 100644 --- a/adb/client/usb_osx.cpp +++ b/adb/client/usb_osx.cpp @@ -518,7 +518,7 @@ int usb_read(usb_handle *handle, void *buf, int len) } if (kIOReturnSuccess == result) - return 0; + return numBytes; else { LOG(ERROR) << "usb_read failed with status: " << std::hex << result; } diff --git a/adb/client/usb_windows.cpp b/adb/client/usb_windows.cpp index 9e00a5d6ceca..ee7f8024fd38 100644 --- a/adb/client/usb_windows.cpp +++ b/adb/client/usb_windows.cpp @@ -415,6 +415,7 @@ int usb_read(usb_handle *handle, void* data, int len) { unsigned long time_out = 0; unsigned long read = 0; int err = 0; + int orig_len = len; D("usb_read %d", len); if (NULL == handle) { @@ -423,9 +424,8 @@ int usb_read(usb_handle *handle, void* data, int len) { goto fail; } - while (len > 0) { - if (!AdbReadEndpointSync(handle->adb_read_pipe, data, len, &read, - time_out)) { + while (len == orig_len) { + if (!AdbReadEndpointSync(handle->adb_read_pipe, data, len, &read, time_out)) { D("AdbReadEndpointSync failed: %s", android::base::SystemErrorCodeToString(GetLastError()).c_str()); err = EIO; @@ -433,11 +433,11 @@ int usb_read(usb_handle *handle, void* data, int len) { } D("usb_read got: %ld, expected: %d", read, len); - data = (char *)data + read; + data = (char*)data + read; len -= read; } - return 0; + return orig_len - len; fail: // Any failure should cause us to kick the device instead of leaving it a diff --git a/adb/commandline.cpp b/adb/commandline.cpp index d626259f2673..5f55ab984e52 100644 --- a/adb/commandline.cpp +++ b/adb/commandline.cpp @@ -212,6 +212,7 @@ static void help() { " kill-server kill the server if it is running\n" " reconnect kick connection from host side to force reconnect\n" " reconnect device kick connection from device side to force reconnect\n" + " reconnect offline reset offline/unauthorized devices to force reconnect\n" "\n" "environment variables:\n" " $ADB_TRACE\n" @@ -1929,7 +1930,7 @@ int adb_commandline(int argc, const char** argv) { return adb_query_command("host:host-features"); } else if (!strcmp(argv[0], "reconnect")) { if (argc == 1) { - return adb_query_command("host:reconnect"); + return adb_query_command(format_host_command(argv[0], transport_type, serial)); } else if (argc == 2) { if (!strcmp(argv[1], "device")) { std::string err; diff --git a/adb/fdevent.cpp b/adb/fdevent.cpp index 04cd8651cfff..72c9eef42523 100644 --- a/adb/fdevent.cpp +++ b/adb/fdevent.cpp @@ -75,13 +75,13 @@ static std::atomic terminate_loop(false); static bool main_thread_valid; static unsigned long main_thread_id; -static void check_main_thread() { +void check_main_thread() { if (main_thread_valid) { CHECK_EQ(main_thread_id, adb_thread_id()); } } -static void set_main_thread() { +void set_main_thread() { main_thread_valid = true; main_thread_id = adb_thread_id(); } diff --git a/adb/fdevent.h b/adb/fdevent.h index 207f9b702893..e32845afca56 100644 --- a/adb/fdevent.h +++ b/adb/fdevent.h @@ -76,9 +76,12 @@ void fdevent_set_timeout(fdevent *fde, int64_t timeout_ms); */ void fdevent_loop(); +void check_main_thread(); + // The following functions are used only for tests. void fdevent_terminate_loop(); size_t fdevent_installed_count(); void fdevent_reset(); +void set_main_thread(); #endif diff --git a/adb/services.cpp b/adb/services.cpp index 43270445162c..9605e6ec07b9 100644 --- a/adb/services.cpp +++ b/adb/services.cpp @@ -347,7 +347,7 @@ static void wait_for_state(int fd, void* data) { std::string error = "unknown error"; const char* serial = sinfo->serial.length() ? sinfo->serial.c_str() : NULL; atransport* t = acquire_one_transport(sinfo->transport_type, serial, &is_ambiguous, &error); - if (t != nullptr && (sinfo->state == kCsAny || sinfo->state == t->connection_state)) { + if (t != nullptr && (sinfo->state == kCsAny || sinfo->state == t->GetConnectionState())) { SendOkay(fd); break; } else if (!is_ambiguous) { diff --git a/adb/sockets.cpp b/adb/sockets.cpp index 59a48f56d1d5..14ad1ff97fe0 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -794,7 +794,7 @@ static int smart_socket_enqueue(asocket* s, apacket* p) { if (!s->transport) { SendFail(s->peer->fd, "device offline (no transport)"); goto fail; - } else if (s->transport->connection_state == kCsOffline) { + } else if (s->transport->GetConnectionState() == kCsOffline) { /* if there's no remote we fail the connection ** right here and terminate it */ diff --git a/adb/test_device.py b/adb/test_device.py index e76aaed5c054..a30972e54a95 100644 --- a/adb/test_device.py +++ b/adb/test_device.py @@ -1188,6 +1188,77 @@ def test_unicode_paths(self): self.device.shell(['rm', '-f', '/data/local/tmp/adb-test-*']) +class DeviceOfflineTest(DeviceTest): + def _get_device_state(self, serialno): + output = subprocess.check_output(self.device.adb_cmd + ['devices']) + for line in output.split('\n'): + m = re.match('(\S+)\s+(\S+)', line) + if m and m.group(1) == serialno: + return m.group(2) + return None + + def test_killed_when_pushing_a_large_file(self): + """ + While running adb push with a large file, kill adb server. + Occasionally the device becomes offline. Because the device is still + reading data without realizing that the adb server has been restarted. + Test if we can bring the device online automatically now. + http://b/32952319 + """ + serialno = subprocess.check_output(self.device.adb_cmd + ['get-serialno']).strip() + # 1. Push a large file + file_path = 'tmp_large_file' + try: + fh = open(file_path, 'w') + fh.write('\0' * (100 * 1024 * 1024)) + fh.close() + subproc = subprocess.Popen(self.device.adb_cmd + ['push', file_path, '/data/local/tmp']) + time.sleep(0.1) + # 2. Kill the adb server + subprocess.check_call(self.device.adb_cmd + ['kill-server']) + subproc.terminate() + finally: + try: + os.unlink(file_path) + except: + pass + # 3. See if the device still exist. + # Sleep to wait for the adb server exit. + time.sleep(0.5) + # 4. The device should be online + self.assertEqual(self._get_device_state(serialno), 'device') + + def test_killed_when_pulling_a_large_file(self): + """ + While running adb pull with a large file, kill adb server. + Occasionally the device can't be connected. Because the device is trying to + send a message larger than what is expected by the adb server. + Test if we can bring the device online automatically now. + """ + serialno = subprocess.check_output(self.device.adb_cmd + ['get-serialno']).strip() + file_path = 'tmp_large_file' + try: + # 1. Create a large file on device. + self.device.shell(['dd', 'if=/dev/zero', 'of=/data/local/tmp/tmp_large_file', + 'bs=1000000', 'count=100']) + # 2. Pull the large file on host. + subproc = subprocess.Popen(self.device.adb_cmd + + ['pull','/data/local/tmp/tmp_large_file', file_path]) + time.sleep(0.1) + # 3. Kill the adb server + subprocess.check_call(self.device.adb_cmd + ['kill-server']) + subproc.terminate() + finally: + try: + os.unlink(file_path) + except: + pass + # 4. See if the device still exist. + # Sleep to wait for the adb server exit. + time.sleep(0.5) + self.assertEqual(self._get_device_state(serialno), 'device') + + def main(): random.seed(0) if len(adb.get_devices()) > 0: diff --git a/adb/transport.cpp b/adb/transport.cpp index 4686841ececb..cc8c1625204c 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -41,6 +42,7 @@ #include "adb_trace.h" #include "adb_utils.h" #include "diagnose_usb.h" +#include "fdevent.h" static void transport_unref(atransport *t); @@ -209,6 +211,11 @@ static void read_transport_thread(void* _t) { put_apacket(p); break; } +#if ADB_HOST + if (p->msg.command == 0) { + continue; + } +#endif } D("%s: received remote packet, sending to transport", t->serial); @@ -271,7 +278,11 @@ static void write_transport_thread(void* _t) { if (active) { D("%s: transport got packet, sending to remote", t->serial); ATRACE_NAME("write_transport write_remote"); - t->write_to_remote(p, t); + if (t->Write(p) != 0) { + D("%s: remote write failed for transport", t->serial); + put_apacket(p); + break; + } } else { D("%s: transport ignoring packet while offline", t->serial); } @@ -493,7 +504,7 @@ static void transport_registration_func(int _fd, unsigned ev, void* data) { } /* don't create transport threads for inaccessible devices */ - if (t->connection_state != kCsNoPerm) { + if (t->GetConnectionState() != kCsNoPerm) { /* initial references are the two threads */ t->ref_count = 2; @@ -538,6 +549,15 @@ void init_transport_registration(void) { transport_registration_func, 0); fdevent_set(&transport_registration_fde, FDE_READ); +#if ADB_HOST + android::base::at_quick_exit([]() { + // To avoid only writing part of a packet to a transport after exit, kick all transports. + std::lock_guard lock(transport_lock); + for (auto t : transport_list) { + t->Kick(); + } + }); +#endif } /* the fdevent select pump is single threaded */ @@ -600,7 +620,7 @@ static int qual_match(const char* to_test, const char* prefix, const char* qual, } atransport* acquire_one_transport(TransportType type, const char* serial, bool* is_ambiguous, - std::string* error_out) { + std::string* error_out, bool accept_any_state) { atransport* result = nullptr; if (serial) { @@ -615,7 +635,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool* std::unique_lock lock(transport_lock); for (const auto& t : transport_list) { - if (t->connection_state == kCsNoPerm) { + if (t->GetConnectionState() == kCsNoPerm) { #if ADB_HOST *error_out = UsbNoPermissionsLongHelpText(); #endif @@ -664,7 +684,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool* lock.unlock(); // Don't return unauthorized devices; the caller can't do anything with them. - if (result && result->connection_state == kCsUnauthorized) { + if (result && result->GetConnectionState() == kCsUnauthorized && !accept_any_state) { *error_out = "device unauthorized.\n"; char* ADB_VENDOR_KEYS = getenv("ADB_VENDOR_KEYS"); *error_out += "This adb server's $ADB_VENDOR_KEYS is "; @@ -676,7 +696,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool* } // Don't return offline devices; the caller can't do anything with them. - if (result && result->connection_state == kCsOffline) { + if (result && result->GetConnectionState() == kCsOffline && !accept_any_state) { *error_out = "device offline"; result = nullptr; } @@ -688,16 +708,38 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool* return result; } +int atransport::Write(apacket* p) { +#if ADB_HOST + std::lock_guard lock(write_msg_lock_); +#endif + return write_func_(p, this); +} + void atransport::Kick() { if (!kicked_) { kicked_ = true; CHECK(kick_func_ != nullptr); +#if ADB_HOST + // On host, adb server should avoid writing part of a packet, so don't + // kick a transport whiling writing a packet. + std::lock_guard lock(write_msg_lock_); +#endif kick_func_(this); } } +ConnectionState atransport::GetConnectionState() const { + return connection_state_; +} + +void atransport::SetConnectionState(ConnectionState state) { + check_main_thread(); + connection_state_ = state; +} + const std::string atransport::connection_state_name() const { - switch (connection_state) { + ConnectionState state = GetConnectionState(); + switch (state) { case kCsOffline: return "offline"; case kCsBootloader: @@ -963,10 +1005,10 @@ void kick_all_tcp_devices() { void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath, unsigned writeable) { - atransport* t = new atransport(); + atransport* t = new atransport((writeable ? kCsOffline : kCsNoPerm)); D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : ""); - init_usb_transport(t, usb, (writeable ? kCsOffline : kCsNoPerm)); + init_usb_transport(t, usb); if (serial) { t->serial = strdup(serial); } @@ -987,12 +1029,13 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev void unregister_usb_transport(usb_handle* usb) { std::lock_guard lock(transport_lock); transport_list.remove_if( - [usb](atransport* t) { return t->usb == usb && t->connection_state == kCsNoPerm; }); + [usb](atransport* t) { return t->usb == usb && t->GetConnectionState() == kCsNoPerm; }); } int check_header(apacket* p, atransport* t) { if (p->msg.magic != (p->msg.command ^ 0xffffffff)) { - VLOG(RWX) << "check_header(): invalid magic"; + VLOG(RWX) << "check_header(): invalid magic command = " << std::hex << p->msg.command + << ", magic = " << p->msg.magic; return -1; } @@ -1020,4 +1063,11 @@ std::shared_ptr atransport::NextKey() { keys_.pop_front(); return result; } +bool atransport::SetSendConnectOnError() { + if (has_send_connect_on_error_) { + return false; + } + has_send_connect_on_error_ = true; + return true; +} #endif diff --git a/adb/transport.h b/adb/transport.h index 4d97fc78b64f..8c15d663d410 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -19,10 +19,12 @@ #include +#include #include #include #include #include +#include #include #include @@ -57,31 +59,35 @@ class atransport { // class in one go is a very large change. Given how bad our testing is, // it's better to do this piece by piece. - atransport() { + atransport(ConnectionState state = kCsOffline) : connection_state_(state) { transport_fde = {}; protocol_version = A_VERSION; max_payload = MAX_PAYLOAD; } - virtual ~atransport() {} int (*read_from_remote)(apacket* p, atransport* t) = nullptr; - int (*write_to_remote)(apacket* p, atransport* t) = nullptr; void (*close)(atransport* t) = nullptr; + + void SetWriteFunction(int (*write_func)(apacket*, atransport*)) { write_func_ = write_func; } void SetKickFunction(void (*kick_func)(atransport*)) { kick_func_ = kick_func; } bool IsKicked() { return kicked_; } + int Write(apacket* p); void Kick(); + // ConnectionState can be read by all threads, but can only be written in the main thread. + ConnectionState GetConnectionState() const; + void SetConnectionState(ConnectionState state); + int fd = -1; int transport_socket = -1; fdevent transport_fde; size_t ref_count = 0; uint32_t sync_token = 0; - ConnectionState connection_state = kCsOffline; bool online = false; TransportType type = kTransportAny; @@ -114,11 +120,13 @@ class atransport { #if ADB_HOST std::shared_ptr NextKey(); + bool SetSendConnectOnError(); #endif char token[TOKEN_SIZE] = {}; size_t failed_auth_attempts = 0; + const std::string serial_name() const { return serial ? serial : ""; } const std::string connection_state_name() const; void update_version(int version, size_t payload); @@ -157,6 +165,7 @@ class atransport { int local_port_for_emulator_ = -1; bool kicked_ = false; void (*kick_func_)(atransport*) = nullptr; + int (*write_func_)(apacket*, atransport*) = nullptr; // A set of features transmitted in the banner with the initial connection. // This is stored in the banner as 'features=feature0,feature1,etc'. @@ -167,8 +176,11 @@ class atransport { // A list of adisconnect callbacks called when the transport is kicked. std::list disconnects_; + std::atomic connection_state_; #if ADB_HOST std::deque> keys_; + std::mutex write_msg_lock_; + bool has_send_connect_on_error_ = false; #endif DISALLOW_COPY_AND_ASSIGN(atransport); @@ -181,8 +193,8 @@ class atransport { * is set to true and nullptr returned. * If no suitable transport is found, error is set and nullptr returned. */ -atransport* acquire_one_transport(TransportType type, const char* serial, - bool* is_ambiguous, std::string* error_out); +atransport* acquire_one_transport(TransportType type, const char* serial, bool* is_ambiguous, + std::string* error_out, bool accept_any_state = false); void kick_transport(atransport* t); void update_transports(void); diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp index 408f51fa18d5..3ee286a12967 100644 --- a/adb/transport_local.cpp +++ b/adb/transport_local.cpp @@ -515,12 +515,11 @@ int init_socket_transport(atransport *t, int s, int adb_port, int local) int fail = 0; t->SetKickFunction(remote_kick); + t->SetWriteFunction(remote_write); t->close = remote_close; t->read_from_remote = remote_read; - t->write_to_remote = remote_write; t->sfd = s; t->sync_token = 1; - t->connection_state = kCsOffline; t->type = kTransportLocal; #if ADB_HOST diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp index 8b38e0334ebb..68689d4a6c3d 100644 --- a/adb/transport_test.cpp +++ b/adb/transport_test.cpp @@ -94,12 +94,13 @@ TEST(transport, SetFeatures) { } TEST(transport, parse_banner_no_features) { + set_main_thread(); atransport t; parse_banner("host::", &t); ASSERT_EQ(0U, t.features().size()); - ASSERT_EQ(kCsHost, t.connection_state); + ASSERT_EQ(kCsHost, t.GetConnectionState()); ASSERT_EQ(nullptr, t.product); ASSERT_EQ(nullptr, t.model); @@ -113,7 +114,7 @@ TEST(transport, parse_banner_product_features) { "host::ro.product.name=foo;ro.product.model=bar;ro.product.device=baz;"; parse_banner(banner, &t); - ASSERT_EQ(kCsHost, t.connection_state); + ASSERT_EQ(kCsHost, t.GetConnectionState()); ASSERT_EQ(0U, t.features().size()); @@ -130,7 +131,7 @@ TEST(transport, parse_banner_features) { "features=woodly,doodly"; parse_banner(banner, &t); - ASSERT_EQ(kCsHost, t.connection_state); + ASSERT_EQ(kCsHost, t.GetConnectionState()); ASSERT_EQ(2U, t.features().size()); ASSERT_TRUE(t.has_feature("woodly")); diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index 516b4f20d1c0..ce419b88d708 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -25,9 +25,115 @@ #include "adb.h" +#if ADB_HOST + +static constexpr size_t MAX_USB_BULK_PACKET_SIZE = 1024u; + +// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. +static int UsbReadMessage(usb_handle* h, amessage* msg) { + D("UsbReadMessage"); + char buffer[MAX_USB_BULK_PACKET_SIZE]; + int n = usb_read(h, buffer, sizeof(buffer)); + if (n == sizeof(*msg)) { + memcpy(msg, buffer, sizeof(*msg)); + } + return n; +} + +// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. +static int UsbReadPayload(usb_handle* h, apacket* p) { + D("UsbReadPayload"); + size_t need_size = p->msg.data_length; + size_t data_pos = 0u; + while (need_size > 0u) { + int n = 0; + if (data_pos + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { + // Read directly to p->data. + size_t rem_size = need_size % MAX_USB_BULK_PACKET_SIZE; + size_t direct_read_size = need_size - rem_size; + if (rem_size && + data_pos + direct_read_size + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { + direct_read_size += MAX_USB_BULK_PACKET_SIZE; + } + n = usb_read(h, &p->data[data_pos], direct_read_size); + if (n < 0) { + D("usb_read(size %zu) failed", direct_read_size); + return n; + } + } else { + // Read indirectly using a buffer. + char buffer[MAX_USB_BULK_PACKET_SIZE]; + n = usb_read(h, buffer, sizeof(buffer)); + if (n < 0) { + D("usb_read(size %zu) failed", sizeof(buffer)); + return -1; + } + size_t copy_size = std::min(static_cast(n), need_size); + D("usb read %d bytes, need %zu bytes, copy %zu bytes", n, need_size, copy_size); + memcpy(&p->data[data_pos], buffer, copy_size); + } + data_pos += n; + need_size -= std::min(static_cast(n), need_size); + } + return static_cast(data_pos); +} + +static int remote_read(apacket* p, atransport* t) { + int n = UsbReadMessage(t->usb, &p->msg); + if (n < 0) { + D("remote usb: read terminated (message)"); + return -1; + } + if (static_cast(n) != sizeof(p->msg) || check_header(p, t)) { + D("remote usb: check_header failed, skip it"); + goto err_msg; + } + if (t->GetConnectionState() == kCsOffline) { + // If we read a wrong msg header declaring a large message payload, don't read its payload. + // Otherwise we may miss true messages from the device. + if (p->msg.command != A_CNXN && p->msg.command != A_AUTH) { + goto err_msg; + } + } + if (p->msg.data_length) { + n = UsbReadPayload(t->usb, p); + if (n < 0) { + D("remote usb: terminated (data)"); + return -1; + } + if (static_cast(n) != p->msg.data_length) { + D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it", + p->msg.data_length, n); + goto err_msg; + } + } + if (check_data(p)) { + D("remote usb: check_data failed, skip it"); + goto err_msg; + } + return 0; + +err_msg: + p->msg.command = 0; + if (t->GetConnectionState() == kCsOffline) { + // If the data toggle of ep_out on device and ep_in on host are not the same, we may receive + // an error message. In this case, resend one A_CNXN message to connect the device. + if (t->SetSendConnectOnError()) { + SendConnectOnHost(t); + } + } + return 0; +} + +#else + +// On Android devices, we rely on the kernel to provide buffered read. +// So we can recover automatically from EOVERFLOW. static int remote_read(apacket *p, atransport *t) { - if(usb_read(t->usb, &p->msg, sizeof(amessage))){ + if (usb_read(t->usb, &p->msg, sizeof(amessage))) { D("remote usb: read terminated (message)"); return -1; } @@ -38,7 +144,7 @@ static int remote_read(apacket *p, atransport *t) } if(p->msg.data_length) { - if(usb_read(t->usb, p->data, p->msg.data_length)){ + if (usb_read(t->usb, p->data, p->msg.data_length)) { D("remote usb: terminated (data)"); return -1; } @@ -51,17 +157,18 @@ static int remote_read(apacket *p, atransport *t) return 0; } +#endif static int remote_write(apacket *p, atransport *t) { unsigned size = p->msg.data_length; - if(usb_write(t->usb, &p->msg, sizeof(amessage))) { + if (usb_write(t->usb, &p->msg, sizeof(amessage))) { D("remote usb: 1 - write terminated"); return -1; } if(p->msg.data_length == 0) return 0; - if(usb_write(t->usb, &p->data, size)) { + if (usb_write(t->usb, &p->data, size)) { D("remote usb: 2 - write terminated"); return -1; } @@ -75,20 +182,17 @@ static void remote_close(atransport *t) t->usb = 0; } -static void remote_kick(atransport *t) -{ +static void remote_kick(atransport* t) { usb_kick(t->usb); } -void init_usb_transport(atransport *t, usb_handle *h, ConnectionState state) -{ +void init_usb_transport(atransport* t, usb_handle* h) { D("transport: usb"); t->close = remote_close; t->SetKickFunction(remote_kick); + t->SetWriteFunction(remote_write); t->read_from_remote = remote_read; - t->write_to_remote = remote_write; t->sync_token = 1; - t->connection_state = state; t->type = kTransportUsb; t->usb = h; }