diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Checksum.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Checksum.java new file mode 100644 index 00000000000000..07ca16d614de21 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/Checksum.java @@ -0,0 +1,50 @@ +// Copyright 2019 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.bazel.repository.downloader; + +import com.google.common.hash.HashCode; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; + +/** The content checksum for an HTTP download, which knows its own type. */ +public class Checksum { + private final KeyType keyType; + private final HashCode hashCode; + + private Checksum(KeyType keyType, HashCode hashCode) { + this.keyType = keyType; + this.hashCode = hashCode; + } + + /** Constructs a new Checksum for a given key type and hash, in hex format. */ + public static Checksum fromString(KeyType keyType, String hash) { + if (!keyType.isValid(hash)) { + throw new IllegalArgumentException("Invalid " + keyType + " checksum '" + hash + "'"); + } + return new Checksum(keyType, HashCode.fromString(hash)); + } + + @Override + public String toString() { + return hashCode.toString(); + } + + public HashCode getHashCode() { + return hashCode; + } + + public KeyType getKeyType() { + return keyType; + } +} diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStream.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStream.java index ee7b1bd0c493f1..51695a52187b00 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStream.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStream.java @@ -15,7 +15,6 @@ package com.google.devtools.build.lib.bazel.repository.downloader; import com.google.common.hash.HashCode; -import com.google.common.hash.HashFunction; import com.google.common.hash.Hasher; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadCompatible; import java.io.IOException; @@ -41,11 +40,10 @@ final class HashInputStream extends InputStream { private final HashCode code; @Nullable private volatile HashCode actual; - HashInputStream( - @WillCloseWhenClosed InputStream delegate, HashFunction function, HashCode code) { + HashInputStream(@WillCloseWhenClosed InputStream delegate, Checksum checksum) { this.delegate = delegate; - this.hasher = function.newHasher(); - this.code = code; + this.hasher = checksum.getKeyType().newHasher(); + this.code = checksum.getHashCode(); } @Override diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java index 36328aecf5a4cc..b5bcabf4473b7b 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java @@ -16,7 +16,7 @@ import static com.google.common.collect.ImmutableSortedSet.toImmutableSortedSet; -import com.google.common.base.Preconditions; +import com.google.common.base.Optional; import com.google.common.base.Predicates; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Ordering; @@ -95,8 +95,8 @@ final class HttpConnectorMultiplexer { this.sleeper = sleeper; } - public HttpStream connect(List urls, String sha256) throws IOException { - return connect(urls, sha256, ImmutableMap.>of()); + public HttpStream connect(List urls, Optional checksum) throws IOException { + return connect(urls, checksum, ImmutableMap.>of()); } /** @@ -116,22 +116,22 @@ public HttpStream connect(List urls, String sha256) throws IOException { * and block until the connection can be renegotiated transparently right where it left off. * * @param urls mirrors by preference; each URL can be: file, http, or https - * @param sha256 hex checksum lazily checked on entire payload, or empty to disable + * @param checksum checksum lazily checked on entire payload, or empty to disable * @return an {@link InputStream} of response payload * @throws IOException if all mirrors are down and contains suppressed exception of each attempt * @throws InterruptedIOException if current thread is being cast into oblivion * @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol */ public HttpStream connect( - List urls, String sha256, Map> authHeaders) throws IOException { - Preconditions.checkNotNull(sha256); + List urls, Optional checksum, Map> authHeaders) + throws IOException { HttpUtils.checkUrlsArgument(urls); if (Thread.interrupted()) { throw new InterruptedIOException(); } // If there's only one URL then there's no need for us to run all our fancy thread stuff. if (urls.size() == 1) { - return establishConnection(urls.get(0), sha256, authHeaders); + return establishConnection(urls.get(0), checksum, authHeaders); } MutexConditionSharedMemory context = new MutexConditionSharedMemory(); // The parent thread always holds the lock except when released by wait(). @@ -140,7 +140,7 @@ public HttpStream connect( long now = clock.currentTimeMillis(); long startAtTime = now; for (URL url : urls) { - context.jobs.add(new WorkItem(url, sha256, startAtTime, authHeaders)); + context.jobs.add(new WorkItem(url, checksum, startAtTime, authHeaders)); startAtTime += FAILOVER_DELAY_MS; } // Create the worker thread pool. @@ -210,13 +210,17 @@ private static class MutexConditionSharedMemory { private static class WorkItem { final URL url; - final String sha256; + final Optional checksum; final long startAtTime; final Map> authHeaders; - WorkItem(URL url, String sha256, long startAtTime, Map> authHeaders) { + WorkItem( + URL url, + Optional checksum, + long startAtTime, + Map> authHeaders) { this.url = url; - this.sha256 = sha256; + this.checksum = checksum; this.startAtTime = startAtTime; this.authHeaders = authHeaders; } @@ -263,7 +267,7 @@ public void run() { // Now we're actually going to attempt to connect to the remote server. HttpStream result; try { - result = establishConnection(work.url, work.sha256, work.authHeaders); + result = establishConnection(work.url, work.checksum, work.authHeaders); } catch (InterruptedIOException e) { // The parent thread got its result from another thread and killed this one. synchronized (context) { @@ -307,7 +311,7 @@ private void tellParentThreadWeAreDone() { } private HttpStream establishConnection( - final URL url, String sha256, Map> additionalHeaders) + final URL url, Optional checksum, Map> additionalHeaders) throws IOException { ImmutableMap headers = REQUEST_HEADERS; try { @@ -327,7 +331,7 @@ private HttpStream establishConnection( return httpStreamFactory.create( connection, url, - sha256, + checksum, new Reconnector() { @Override public URLConnection connect(Throwable cause, ImmutableMap extraHeaders) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java index d9fad055bda317..5efe7b52d4975b 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java @@ -72,12 +72,12 @@ public void setTimeoutScaling(float timeoutScaling) { /** * Downloads file to disk and returns path. * - *

If the SHA256 checksum and path to the repository cache is specified, attempt to load the - * file from the {@link RepositoryCache}. If it doesn't exist, proceed to download the file and - * load it into the cache prior to returning the value. + *

If the checksum and path to the repository cache is specified, attempt to load the file from + * the {@link RepositoryCache}. If it doesn't exist, proceed to download the file and load it into + * the cache prior to returning the value. * * @param urls list of mirror URLs with identical content - * @param sha256 valid SHA256 hex checksum string which is checked, or empty to disable + * @param checksum valid checksum which is checked, or empty to disable * @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this * @param output destination filename if {@code type} is absent, otherwise output directory * @param eventHandler CLI progress reporter @@ -91,7 +91,7 @@ public void setTimeoutScaling(float timeoutScaling) { public Path download( List urls, Map> authHeaders, - String sha256, + Optional checksum, String canonicalId, Optional type, Path output, @@ -116,14 +116,15 @@ public Path download( } Path destination = getDownloadDestination(mainUrl, type, output); - // Is set to true if the value should be cached by the sha256 value provided - boolean isCachingByProvidedSha256 = false; + // Is set to true if the value should be cached by the checksum value provided + boolean isCachingByProvidedChecksum = false; - if (!sha256.isEmpty()) { + if (checksum.isPresent()) { + String cacheKey = checksum.get().toString(); + KeyType cacheKeyType = checksum.get().getKeyType(); try { - String currentSha256 = - RepositoryCache.getChecksum(KeyType.SHA256, destination); - if (currentSha256.equals(sha256)) { + String currentChecksum = RepositoryCache.getChecksum(cacheKeyType, destination); + if (currentChecksum.equals(cacheKey)) { // No need to download. return destination; } @@ -132,14 +133,14 @@ public Path download( } if (repositoryCache.isEnabled()) { - isCachingByProvidedSha256 = true; + isCachingByProvidedChecksum = true; try { Path cachedDestination = - repositoryCache.get(sha256, destination, KeyType.SHA256, canonicalId); + repositoryCache.get(cacheKey, destination, cacheKeyType, canonicalId); if (cachedDestination != null) { // Cache hit! - eventHandler.post(new RepositoryCacheHitEvent(repo, sha256, mainUrl)); + eventHandler.post(new RepositoryCacheHitEvent(repo, cacheKey, mainUrl)); return cachedDestination; } } catch (IOException e) { @@ -163,16 +164,16 @@ public Path download( boolean match = false; Path candidate = dir.getRelative(destination.getBaseName()); try { - match = RepositoryCache.getChecksum(KeyType.SHA256, candidate).equals(sha256); + match = RepositoryCache.getChecksum(cacheKeyType, candidate).equals(cacheKey); } catch (IOException e) { // Not finding anything in a distdir is a normal case, so handle it absolutely // quietly. In fact, it is not uncommon to specify a whole list of dist dirs, // with the asumption that only one will contain an entry. } if (match) { - if (isCachingByProvidedSha256) { + if (isCachingByProvidedChecksum) { try { - repositoryCache.put(sha256, candidate, KeyType.SHA256, canonicalId); + repositoryCache.put(cacheKey, candidate, cacheKeyType, canonicalId); } catch (IOException e) { eventHandler.handle( Event.warn("Failed to copy " + candidate + " to repository cache: " + e)); @@ -201,7 +202,7 @@ public Path download( // Connect to the best mirror and download the file, while reporting progress to the CLI. semaphore.acquire(); boolean success = false; - try (HttpStream payload = multiplexer.connect(urls, sha256, authHeaders); + try (HttpStream payload = multiplexer.connect(urls, checksum, authHeaders); OutputStream out = destination.getOutputStream()) { ByteStreams.copy(payload, out); success = true; @@ -215,8 +216,9 @@ public Path download( eventHandler.post(new FetchEvent(urls.get(0).toString(), success)); } - if (isCachingByProvidedSha256) { - repositoryCache.put(sha256, destination, KeyType.SHA256, canonicalId); + if (isCachingByProvidedChecksum) { + repositoryCache.put( + checksum.get().toString(), destination, checksum.get().getKeyType(), canonicalId); } else if (repositoryCache.isEnabled()) { String newSha256 = repositoryCache.put(destination, KeyType.SHA256, canonicalId); eventHandler.handle(Event.info("SHA256 (" + urls.get(0) + ") = " + newSha256)); diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStream.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStream.java index 4921b1504fe436..718ed5ba420753 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStream.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStream.java @@ -14,12 +14,11 @@ package com.google.devtools.build.lib.bazel.repository.downloader; +import com.google.common.base.Optional; import com.google.common.base.Splitter; import com.google.common.base.Strings; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import com.google.common.hash.HashCode; -import com.google.common.hash.Hashing; import com.google.common.io.ByteStreams; import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadCompatible; @@ -62,9 +61,9 @@ static class Factory { HttpStream create( @WillCloseWhenClosed URLConnection connection, URL originalUrl, - String sha256, + Optional checksum, Reconnector reconnector) - throws IOException { + throws IOException { InputStream stream = new InterruptibleInputStream(connection.getInputStream()); try { // If server supports range requests, we can retry on read errors. See RFC7233 ยง 2.3. @@ -89,8 +88,8 @@ HttpStream create( stream = new GZIPInputStream(stream, GZIP_BUFFER_BYTES); } - if (!sha256.isEmpty()) { - stream = new HashInputStream(stream, Hashing.sha256(), HashCode.fromString(sha256)); + if (checksum.isPresent()) { + stream = new HashInputStream(stream, checksum.get()); if (retrier != null) { retrier.disabled = true; } diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java index 25b521a7afc6c9..2e10b6ae8482d8 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java @@ -26,6 +26,7 @@ import com.google.devtools.build.lib.bazel.repository.DecompressorValue; import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache; import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; +import com.google.devtools.build.lib.bazel.repository.downloader.Checksum; import com.google.devtools.build.lib.bazel.repository.downloader.HttpDownloader; import com.google.devtools.build.lib.bazel.repository.downloader.HttpUtils; import com.google.devtools.build.lib.cmdline.Label; @@ -474,6 +475,12 @@ public StructImpl download( warnAboutSha256Error(urls, sha256); sha256 = ""; } + Optional checksum; + if (sha256.isEmpty()) { + checksum = Optional.absent(); + } else { + checksum = Optional.of(Checksum.fromString(KeyType.SHA256, sha256)); + } SkylarkPath outputPath = getPath("download()", output); WorkspaceRuleEvent w = WorkspaceRuleEvent.newDownloadEvent( @@ -487,7 +494,7 @@ public StructImpl download( httpDownloader.download( urls, authHeaders, - sha256, + checksum, canonicalId, Optional.absent(), outputPath.getPath(), @@ -579,6 +586,12 @@ public StructImpl downloadAndExtract( warnAboutSha256Error(urls, sha256); sha256 = ""; } + Optional checksum; + if (sha256.isEmpty()) { + checksum = Optional.absent(); + } else { + checksum = Optional.of(Checksum.fromString(KeyType.SHA256, sha256)); + } WorkspaceRuleEvent w = WorkspaceRuleEvent.newDownloadAndExtractEvent( @@ -601,7 +614,7 @@ public StructImpl downloadAndExtract( httpDownloader.download( urls, authHeaders, - sha256, + checksum, canonicalId, Optional.of(type), outputPath.getPath(), diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/BUILD b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/BUILD index cdd037438259ab..e5873d4a0dc788 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/BUILD +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/BUILD @@ -19,6 +19,7 @@ java_test( deps = [ "//src/main/java/com/google/devtools/build/lib:events", "//src/main/java/com/google/devtools/build/lib:util", + "//src/main/java/com/google/devtools/build/lib/bazel/repository/cache", "//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader", "//src/test/java/com/google/devtools/build/lib:foundations_testutil", "//src/test/java/com/google/devtools/build/lib:test_runner", diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStreamTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStreamTest.java index 3c97cf30f727a6..9c5d5ae6ffb8c7 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStreamTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HashInputStreamTest.java @@ -17,9 +17,8 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; -import com.google.common.hash.HashCode; -import com.google.common.hash.Hashing; import com.google.common.io.CharStreams; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStreamReader; @@ -43,8 +42,7 @@ public void validChecksum_readsOk() throws Exception { new InputStreamReader( new HashInputStream( new ByteArrayInputStream("hello".getBytes(UTF_8)), - Hashing.sha1(), - HashCode.fromString("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d")), + Checksum.fromString(KeyType.SHA1, "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d")), UTF_8)) { assertThat(CharStreams.toString(reader)).isEqualTo("hello"); } @@ -58,8 +56,7 @@ public void badChecksum_throwsIOException() throws Exception { new InputStreamReader( new HashInputStream( new ByteArrayInputStream("hello".getBytes(UTF_8)), - Hashing.sha1(), - HashCode.fromString("0000000000000000000000000000000000000000")), + Checksum.fromString(KeyType.SHA1, "0000000000000000000000000000000000000000")), UTF_8)) { assertThat(CharStreams.toString(reader)) .isNull(); // Only here to make @CheckReturnValue happy. diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerIntegrationTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerIntegrationTest.java index 4490bb0ad01e9c..f5dccf879e797b 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerIntegrationTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerIntegrationTest.java @@ -26,7 +26,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; import com.google.devtools.build.lib.events.ExtendedEventHandler; import com.google.devtools.build.lib.testutil.ManualClock; import com.google.devtools.build.lib.util.Sleeper; @@ -79,6 +81,11 @@ public class HttpConnectorMultiplexerIntegrationTest { private final HttpConnectorMultiplexer multiplexer = new HttpConnectorMultiplexer(eventHandler, connector, httpStreamFactory, clock, sleeper); + private static final Optional HELLO_SHA256 = + Optional.of( + Checksum.fromString( + KeyType.SHA256, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")); + @Before public void before() throws Exception { when(proxyHelper.createProxyIfNeeded(any(URL.class))).thenReturn(Proxy.NO_PROXY); @@ -121,11 +128,11 @@ public Object call() throws Exception { phaser.arriveAndAwaitAdvance(); phaser.arriveAndDeregister(); try (HttpStream stream = - multiplexer.connect( - ImmutableList.of( - new URL(String.format("http://localhost:%d", server1.getLocalPort())), - new URL(String.format("http://localhost:%d", server2.getLocalPort()))), - "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")) { + multiplexer.connect( + ImmutableList.of( + new URL(String.format("http://localhost:%d", server1.getLocalPort())), + new URL(String.format("http://localhost:%d", server2.getLocalPort()))), + HELLO_SHA256)) { assertThat(toByteArray(stream)).isEqualTo("hello".getBytes(US_ASCII)); } } @@ -186,11 +193,11 @@ public Object call() throws Exception { } }); try (HttpStream stream = - multiplexer.connect( - ImmutableList.of( - new URL(String.format("http://localhost:%d", server1.getLocalPort())), - new URL(String.format("http://localhost:%d", server2.getLocalPort()))), - "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")) { + multiplexer.connect( + ImmutableList.of( + new URL(String.format("http://localhost:%d", server1.getLocalPort())), + new URL(String.format("http://localhost:%d", server2.getLocalPort()))), + HELLO_SHA256)) { assertThat(toByteArray(stream)).isEqualTo("hello".getBytes(US_ASCII)); } } @@ -231,7 +238,7 @@ public Object call() throws Exception { new URL(String.format("http://localhost:%d", server1.getLocalPort())), new URL(String.format("http://localhost:%d", server2.getLocalPort())), new URL(String.format("http://localhost:%d", server3.getLocalPort()))), - "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9825"); + HELLO_SHA256); } } } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerTest.java index a93e50f32214ad..511d08ca2f17f2 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexerTest.java @@ -22,7 +22,6 @@ import static java.util.Arrays.asList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; @@ -33,8 +32,10 @@ import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector; import com.google.devtools.build.lib.events.EventHandler; import com.google.devtools.build.lib.testutil.ManualClock; @@ -69,6 +70,11 @@ public class HttpConnectorMultiplexerTest { private static final byte[] data2 = "second".getBytes(UTF_8); private static final byte[] data3 = "third".getBytes(UTF_8); + private static final Optional DUMMY_CHECKSUM = + Optional.of( + Checksum.fromString( + KeyType.SHA256, "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd")); + @Rule public final ExpectedException thrown = ExpectedException.none(); @@ -94,47 +100,48 @@ public void before() throws Exception { when(connector.connect(eq(URL1), any(ImmutableMap.class))).thenReturn(connection1); when(connector.connect(eq(URL2), any(ImmutableMap.class))).thenReturn(connection2); when(connector.connect(eq(URL3), any(ImmutableMap.class))).thenReturn(connection3); - when(streamFactory - .create(same(connection1), any(URL.class), anyString(), any(Reconnector.class))) + when(streamFactory.create( + same(connection1), any(URL.class), any(Optional.class), any(Reconnector.class))) .thenReturn(stream1); - when(streamFactory - .create(same(connection2), any(URL.class), anyString(), any(Reconnector.class))) + when(streamFactory.create( + same(connection2), any(URL.class), any(Optional.class), any(Reconnector.class))) .thenReturn(stream2); - when(streamFactory - .create(same(connection3), any(URL.class), anyString(), any(Reconnector.class))) + when(streamFactory.create( + same(connection3), any(URL.class), any(Optional.class), any(Reconnector.class))) .thenReturn(stream3); } @Test public void emptyList_throwsIae() throws Exception { thrown.expect(IllegalArgumentException.class); - multiplexer.connect(ImmutableList.of(), ""); + multiplexer.connect(ImmutableList.of(), null); } @Test public void ftpUrl_throwsIae() throws Exception { thrown.expect(IllegalArgumentException.class); - multiplexer.connect(asList(new URL("ftp://lol.example")), ""); + multiplexer.connect(asList(new URL("ftp://lol.example")), null); } @Test public void threadIsInterrupted_throwsIeProntoAndDoesNothingElse() throws Exception { final AtomicBoolean wasInterrupted = new AtomicBoolean(true); - Thread task = new Thread( - new Runnable() { - @Override - public void run() { - Thread.currentThread().interrupt(); - try { - multiplexer.connect(asList(new URL("http://lol.example")), ""); - } catch (InterruptedIOException ignored) { - return; - } catch (Exception ignored) { - // ignored - } - wasInterrupted.set(false); - } - }); + Thread task = + new Thread( + new Runnable() { + @Override + public void run() { + Thread.currentThread().interrupt(); + try { + multiplexer.connect(asList(new URL("http://lol.example")), null); + } catch (InterruptedIOException ignored) { + return; + } catch (Exception ignored) { + // ignored + } + wasInterrupted.set(false); + } + }); task.start(); task.join(); assertThat(wasInterrupted.get()).isTrue(); @@ -143,10 +150,11 @@ public void run() { @Test public void singleUrl_justCallsConnector() throws Exception { - assertThat(toByteArray(multiplexer.connect(asList(URL1), "abc"))).isEqualTo(data1); + assertThat(toByteArray(multiplexer.connect(asList(URL1), DUMMY_CHECKSUM))).isEqualTo(data1); verify(connector).connect(eq(URL1), any(ImmutableMap.class)); verify(streamFactory) - .create(any(URLConnection.class), any(URL.class), eq("abc"), any(Reconnector.class)); + .create( + any(URLConnection.class), any(URL.class), eq(DUMMY_CHECKSUM), any(Reconnector.class)); verifyNoMoreInteractions(sleeper, connector, streamFactory); } @@ -154,7 +162,7 @@ public void singleUrl_justCallsConnector() throws Exception { public void multipleUrlsFail_throwsIOException() throws Exception { when(connector.connect(any(URL.class), any(ImmutableMap.class))).thenThrow(new IOException()); IOException e = - assertThrows(IOException.class, () -> multiplexer.connect(asList(URL1, URL2, URL3), "")); + assertThrows(IOException.class, () -> multiplexer.connect(asList(URL1, URL2, URL3), null)); assertThat(e).hasMessageThat().contains("All mirrors are down"); verify(connector, times(3)).connect(any(URL.class), any(ImmutableMap.class)); verify(sleeper, times(2)).sleepMillis(anyLong()); @@ -172,12 +180,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } }).when(sleeper).sleepMillis(anyLong()); when(connector.connect(eq(URL1), any(ImmutableMap.class))).thenThrow(new IOException()); - assertThat(toByteArray(multiplexer.connect(asList(URL1, URL2), "abc"))).isEqualTo(data2); + assertThat(toByteArray(multiplexer.connect(asList(URL1, URL2), DUMMY_CHECKSUM))) + .isEqualTo(data2); assertThat(clock.currentTimeMillis()).isEqualTo(1000L); verify(connector).connect(eq(URL1), any(ImmutableMap.class)); verify(connector).connect(eq(URL2), any(ImmutableMap.class)); verify(streamFactory) - .create(any(URLConnection.class), any(URL.class), eq("abc"), any(Reconnector.class)); + .create( + any(URLConnection.class), any(URL.class), eq(DUMMY_CHECKSUM), any(Reconnector.class)); verify(sleeper).sleepMillis(anyLong()); verifyNoMoreInteractions(sleeper, connector, streamFactory); } @@ -204,7 +214,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable { return null; } }).when(sleeper).sleepMillis(anyLong()); - assertThat(toByteArray(multiplexer.connect(asList(URL1, URL2), "abc"))).isEqualTo(data1); + assertThat(toByteArray(multiplexer.connect(asList(URL1, URL2), DUMMY_CHECKSUM))) + .isEqualTo(data1); assertThat(wasInterrupted.get()).isTrue(); } @@ -234,20 +245,21 @@ public URLConnection answer(InvocationOnMock invocation) throws Throwable { throw new RuntimeException(); } }); - Thread task = new Thread( - new Runnable() { - @Override - public void run() { - try { - multiplexer.connect(asList(URL1, URL2), ""); - } catch (InterruptedIOException ignored) { - return; - } catch (Exception ignored) { - // ignored - } - wasInterrupted3.set(false); - } - }); + Thread task = + new Thread( + new Runnable() { + @Override + public void run() { + try { + multiplexer.connect(asList(URL1, URL2), null); + } catch (InterruptedIOException ignored) { + return; + } catch (Exception ignored) { + // ignored + } + wasInterrupted3.set(false); + } + }); task.start(); barrier.await(); task.interrupt(); diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStreamTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStreamTest.java index 6af4caf3b8adfe..011ba7ad7c2215 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStreamTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpStreamTest.java @@ -23,8 +23,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; import com.google.common.hash.Hashing; import com.google.common.io.ByteStreams; +import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType; import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -54,10 +56,14 @@ public class HttpStreamTest { private static final Random randoCalrissian = new Random(); private static final byte[] data = "hello".getBytes(UTF_8); - private static final String GOOD_CHECKSUM = - "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"; - private static final String BAD_CHECKSUM = - "0000000000000000000000000000000000000000000000000000000000000000"; + private static final Optional GOOD_CHECKSUM = + Optional.of( + Checksum.fromString( + KeyType.SHA256, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")); + private static final Optional BAD_CHECKSUM = + Optional.of( + Checksum.fromString( + KeyType.SHA256, "0000000000000000000000000000000000000000000000000000000000000000")); private static final URL AURL = makeUrl("http://doodle.example"); @Rule @@ -86,7 +92,8 @@ public InputStream answer(InvocationOnMock invocation) throws Throwable { @Test public void noChecksum_readsOk() throws Exception { - try (HttpStream stream = streamFactory.create(connection, AURL, "", reconnector)) { + try (HttpStream stream = + streamFactory.create(connection, AURL, Optional.absent(), reconnector)) { assertThat(toByteArray(stream)).isEqualTo(data); } } @@ -112,8 +119,13 @@ public void bigDataWithValidChecksum_readsOk() throws Exception { randoCalrissian.nextBytes(bigData); when(connection.getInputStream()).thenReturn(new ByteArrayInputStream(bigData)); try (HttpStream stream = - streamFactory.create( - connection, AURL, Hashing.sha256().hashBytes(bigData).toString(), reconnector)) { + streamFactory.create( + connection, + AURL, + Optional.of( + Checksum.fromString( + KeyType.SHA256, Hashing.sha256().hashBytes(bigData).toString())), + reconnector)) { assertThat(toByteArray(stream)).isEqualTo(bigData); } } @@ -137,7 +149,7 @@ public void httpServerSaidGzippedButNotGzipped_throwsZipExceptionInCreate() thro when(connection.getURL()).thenReturn(AURL); when(connection.getContentEncoding()).thenReturn("gzip"); thrown.expect(ZipException.class); - streamFactory.create(connection, AURL, "", reconnector); + streamFactory.create(connection, AURL, Optional.absent(), reconnector); } @Test @@ -145,7 +157,8 @@ public void javascriptGzippedInTransit_automaticallyGunzips() throws Exception { when(connection.getURL()).thenReturn(AURL); when(connection.getContentEncoding()).thenReturn("x-gzip"); when(connection.getInputStream()).thenReturn(new ByteArrayInputStream(gzipData(data))); - try (HttpStream stream = streamFactory.create(connection, AURL, "", reconnector)) { + try (HttpStream stream = + streamFactory.create(connection, AURL, Optional.absent(), reconnector)) { assertThat(toByteArray(stream)).isEqualTo(data); } } @@ -156,7 +169,8 @@ public void serverSaysTarballPathIsGzipped_doesntAutomaticallyGunzip() throws Ex when(connection.getURL()).thenReturn(new URL("http://doodle.example/foo.tar.gz")); when(connection.getContentEncoding()).thenReturn("gzip"); when(connection.getInputStream()).thenReturn(new ByteArrayInputStream(gzData)); - try (HttpStream stream = streamFactory.create(connection, AURL, "", reconnector)) { + try (HttpStream stream = + streamFactory.create(connection, AURL, Optional.absent(), reconnector)) { assertThat(toByteArray(stream)).isEqualTo(gzData); } } @@ -164,22 +178,24 @@ public void serverSaysTarballPathIsGzipped_doesntAutomaticallyGunzip() throws Ex @Test public void threadInterrupted_haltsReadingAndThrowsInterrupt() throws Exception { final AtomicBoolean wasInterrupted = new AtomicBoolean(); - Thread thread = new Thread( - new Runnable() { - @Override - public void run() { - try (HttpStream stream = streamFactory.create(connection, AURL, "", reconnector)) { - stream.read(); - Thread.currentThread().interrupt(); - stream.read(); - fail(); - } catch (InterruptedIOException expected) { - wasInterrupted.set(true); - } catch (IOException ignored) { - // ignored - } - } - }); + Thread thread = + new Thread( + new Runnable() { + @Override + public void run() { + try (HttpStream stream = + streamFactory.create(connection, AURL, Optional.absent(), reconnector)) { + stream.read(); + Thread.currentThread().interrupt(); + stream.read(); + fail(); + } catch (InterruptedIOException expected) { + wasInterrupted.set(true); + } catch (IOException ignored) { + // ignored + } + } + }); thread.start(); thread.join(); assertThat(wasInterrupted.get()).isTrue();