Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6.0.0] [remote/downloader] Migrate Downloader to take Credentials #16732

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2022 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.devtools.build.lib.authandtls;

import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.net.URI;
import java.util.List;
import java.util.Map;

/** Implementation of {@link Credentials} which provides a static set of credentials. */
public final class StaticCredentials extends Credentials {
public static final StaticCredentials EMPTY = new StaticCredentials(ImmutableMap.of());

private final ImmutableMap<URI, Map<String, List<String>>> credentials;

public StaticCredentials(Map<URI, Map<String, List<String>>> credentials) {
Preconditions.checkNotNull(credentials);

this.credentials = ImmutableMap.copyOf(credentials);
}

@Override
public String getAuthenticationType() {
return "static";
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) {
Preconditions.checkNotNull(uri);

return credentials.getOrDefault(uri, ImmutableMap.of());
}

@Override
public boolean hasRequestMetadata() {
return true;
}

@Override
public boolean hasRequestMetadataOnly() {
return true;
}

@Override
public void refresh() {
// Can't refresh static credentials.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -47,7 +47,7 @@ public void setDelegate(@Nullable Downloader delegate) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -60,6 +60,6 @@ public void download(
downloader = delegate;
}
downloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCacheHitEvent;
Expand Down Expand Up @@ -256,7 +257,7 @@ public Path download(
try {
downloader.download(
rewrittenUrls,
rewrittenAuthHeaders,
new StaticCredentials(rewrittenAuthHeaders),
checksum,
canonicalId,
destination,
Expand Down Expand Up @@ -337,7 +338,7 @@ public byte[] downloadAndReadOneUrl(
for (int attempt = 0; attempt <= retries; ++attempt) {
try {
return httpDownloader.downloadAndReadOneUrl(
rewrittenUrls.get(0), authHeaders, eventHandler, clientEnv);
rewrittenUrls.get(0), new StaticCredentials(authHeaders), eventHandler, clientEnv);
} catch (ContentLengthMismatchException e) {
if (attempt == retries) {
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand All @@ -33,7 +33,7 @@ public interface Downloader {
* caller is responsible for cleaning up outputs of failed downloads.
*
* @param urls list of mirror URLs with identical content
* @param authHeaders map of authentication headers per URL
* @param credentials credentials to use when connecting to URLs
* @param checksum valid checksum which is checked, or absent to disable
* @param output path to the destination file to write
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
Expand All @@ -42,7 +42,7 @@ public interface Downloader {
*/
void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.EventHandler;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
Expand Down Expand Up @@ -74,7 +75,7 @@ final class HttpConnectorMultiplexer {
}

public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOException {
return connect(url, checksum, ImmutableMap.of(), Optional.absent());
return connect(url, checksum, StaticCredentials.EMPTY, Optional.absent());
}

/**
Expand All @@ -87,25 +88,22 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
*
* @param url the URL to conenct to. can be: file, http, or https
* @param checksum checksum lazily checked on entire payload, or empty to disable
* @param authHeaders the authentication headers
* @param credentials the credentials
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
* @return an {@link InputStream} of response payload
* @throws IOException if all mirrors are down and contains suppressed exception of each attempt
* @throws InterruptedIOException if current thread is being cast into oblivion
* @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol
*/
public HttpStream connect(
URL url,
Optional<Checksum> checksum,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<String> type)
URL url, Optional<Checksum> checksum, Credentials credentials, Optional<String> type)
throws IOException {
Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url));
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
Function<URL, ImmutableMap<String, List<String>>> headerFunction =
getHeaderFunction(REQUEST_HEADERS, authHeaders);
getHeaderFunction(REQUEST_HEADERS, credentials);
URLConnection connection = connector.connect(url, headerFunction);
return httpStreamFactory.create(
connection,
Expand All @@ -127,21 +125,20 @@ public HttpStream connect(

@VisibleForTesting
static Function<URL, ImmutableMap<String, List<String>>> getHeaderFunction(
Map<String, List<String>> baseHeaders,
Map<URI, Map<String, List<String>>> additionalHeaders) {
Map<String, List<String>> baseHeaders, Credentials credentials) {
Preconditions.checkNotNull(baseHeaders);
Preconditions.checkNotNull(credentials);

return url -> {
ImmutableMap<String, List<String>> headers = ImmutableMap.copyOf(baseHeaders);
Map<String, List<String>> headers = new HashMap<>(baseHeaders);
try {
if (additionalHeaders.containsKey(url.toURI())) {
Map<String, List<String>> newHeaders = new HashMap<>(headers);
newHeaders.putAll(additionalHeaders.get(url.toURI()));
headers = ImmutableMap.copyOf(newHeaders);
}
} catch (URISyntaxException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), still try
// to do the connection, not adding authentication information as we cannot look it up.
headers.putAll(credentials.getRequestMetadata(url.toURI()));
} catch (URISyntaxException | IOException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), or fetching
// credentials fails for any other reason, still try to do the connection, not adding
// authentication information as we cannot look it up.
}
return headers;
return ImmutableMap.copyOf(headers);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand All @@ -31,7 +32,6 @@
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -63,7 +63,7 @@ public void setTimeoutScaling(float timeoutScaling) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -82,7 +82,7 @@ public void download(
for (URL url : urls) {
SEMAPHORE.acquire();

try (HttpStream payload = multiplexer.connect(url, checksum, authHeaders, type);
try (HttpStream payload = multiplexer.connect(url, checksum, credentials, type);
OutputStream out = destination.getOutputStream()) {
try {
ByteStreams.copy(payload, out);
Expand Down Expand Up @@ -132,7 +132,7 @@ public void download(
/** Downloads the contents of one URL and reads it into a byte array. */
public byte[] downloadAndReadOneUrl(
URL url,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
ExtendedEventHandler eventHandler,
Map<String, String> clientEnv)
throws IOException, InterruptedException {
Expand All @@ -141,7 +141,7 @@ public byte[] downloadAndReadOneUrl(
ByteArrayOutputStream out = new ByteArrayOutputStream();
SEMAPHORE.acquire();
try (HttpStream payload =
multiplexer.connect(url, Optional.absent(), authHeaders, Optional.absent())) {
multiplexer.connect(url, Optional.absent(), credentials, Optional.absent())) {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
// SocketTimeoutExceptions are InterruptedIOExceptions; however they do not signify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/remote/options",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//src/main/java/com/google/devtools/build/lib/vfs",
"//third_party:auth",
"//third_party:guava",
"//third_party:jsr305",
"//third_party/grpc-java:grpc-jar",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import build.bazel.remote.asset.v1.Qualifier;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.RequestMetadata;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
Expand All @@ -41,7 +42,6 @@
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -110,7 +110,7 @@ public void close() {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -154,7 +154,7 @@ public void download(
eventHandler.handle(
Event.warn("Remote Cache: " + Utils.grpcAwareErrorMessage(e, verboseFailures)));
fallbackDownloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -163,7 +164,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap.of("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA==")));

Function<URL, ImmutableMap<String, List<String>>> headerFunction =
HttpConnectorMultiplexer.getHeaderFunction(baseHeaders, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
baseHeaders, new StaticCredentials(additionalHeaders));

// Unreleated URL
assertThat(headerFunction.apply(new URL("http://example.org/some/path/file.txt")))
Expand Down Expand Up @@ -215,7 +217,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap<String, List<String>> annonAuth =
ImmutableMap.of("Authentication", ImmutableList.of("YW5vbnltb3VzOmZvb0BleGFtcGxlLm9yZw=="));
Function<URL, ImmutableMap<String, List<String>>> combinedHeaders =
HttpConnectorMultiplexer.getHeaderFunction(annonAuth, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
annonAuth, new StaticCredentials(additionalHeaders));
assertThat(combinedHeaders.apply(new URL("http://hosting.example.com/user/foo/file.txt")))
.containsExactly("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA=="));
assertThat(combinedHeaders.apply(new URL("http://unreleated.example.org/user/foo/file.txt")))
Expand Down
Loading