Skip to content

Commit

Permalink
Introduce maxCallDepth param to terminate requests that have hopped t…
Browse files Browse the repository at this point in the history
…oo much (linkerd#2300)

This PR introduces `maxCallDepth` param that defaults to 1000 at the moment.  This parameter, combined with a `MaxCallDepthFilter` allows us to detect requests that have done too many hops and return a proper `400` status and message.

Fixes linkerd#1411 

Signed-off-by: Zahari Dichev <zaharidichev@gmail.com>
  • Loading branch information
zaharidichev authored and adleong committed Jul 25, 2019
1 parent 9026195 commit d87f193
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ object Headers {
val Path = ":path"
val Scheme = ":scheme"
val Status = ":status"
val Via = "via"

def apply(pairs: Seq[(String, String)]): Headers =
new Impl(pairs)
Expand Down
1 change: 1 addition & 0 deletions linkerd/docs/protocol-h2.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ initialStreamWindowBytes | 64KB | Configures `SETTINGS_INITIAL_WINDOW_SIZE` on s
maxConcurrentStreamsPerConnection | 1000 | Configures `SETTINGS_MAX_CONCURRENT_STREAMS` on new streams.
maxFrameBytes | 16KB | Configures `SETTINGS_MAX_FRAME_SIZE` on new streams.
maxHeaderListByts | none | Configures `SETTINGS_MAX_HEADER_LIST_SIZE` on new streams.
maxCallDepth | 10 | If set, limits the number of maximum hops. The number of calls is derived by inspecting the Via header. This can be used to prevent proxy loops.

## HTTP/2 Service Parameters

Expand Down
1 change: 1 addition & 0 deletions linkerd/docs/protocol-http.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Key | Default Value | Description
--- | ------------- | -----------
addForwardedHeader | null | If set, a `Forwarded` header is added to all requests. See [below](#http-1-1-forwarded).
timestampHeader | null | If set, the specified header will be added to outbound requests with a timestamp. See [below](#http-1-1-timestamp).
maxCallDepth | 1000 | If set, limits the number of maximum hops. The number of hops is derived by inspecting the Via header. This can be used to prevent proxy loops.

<a name="http-1-1-timestamp"></a>
### Adding Timestamp Headers ###
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,4 +544,38 @@ class H2EndToEndTest extends FunSuite {
await(router.close())
}
}


test("returns 400 for requests that have more than allowed hops") {
val config =
s"""|routers:
|- protocol: h2
| servers:
| - port: 0
| maxCallDepth: 2
|""".stripMargin

val linker = Linker.Initializers(Seq(H2Initializer)).load(config)
val router = linker.routers.head.initialize()
val server = router.servers.head.serve()
val client = Upstream.mk(server)

val req = Request(Headers(
Headers.Via -> "hop1, hop2, hop3",
Headers.Authority -> "dog",
Headers.Path -> "/",
Headers.Method -> "get",
LinkerdHeaders.Ctx.Dtab.UserKey -> "/foo=>/bar"
), Stream.empty())

val rsp = await(client(req))
val expectedMessage = "Maximum number of calls (2) has been exceeded. Please check for proxy loops."
assert(rsp.status == Status.BadRequest)
assert(await(rsp.stream.readDataString) == expectedMessage)

await(client.close())
await(server.close())
await(router.close())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import io.buoyant.config.PolymorphicConfig
import io.buoyant.linkerd.protocol.h2._
import io.buoyant.router.h2.ClassifiedRetries.{BufferSize, ClassificationTimeout}
import io.buoyant.router.h2.{ClassifiedRetryFilter, DupRequest, H2AddForwardedHeader}
import io.buoyant.router.http.ForwardClientCertFilter
import io.buoyant.router.http.{ForwardClientCertFilter, MaxCallDepthFilter}
import io.buoyant.router.{ClassifiedRetries, H2, RoutingFactory}
import io.netty.handler.ssl.ApplicationProtocolNames
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -256,6 +256,7 @@ class H2ServerConfig extends ServerConfig with H2EndpointConfig {

var maxConcurrentStreamsPerConnection: Option[Int] = None
var addForwardedHeader: Option[AddForwardedHeaderConfig] = None
var maxCallDepth: Option[Int] = None

@JsonIgnore
override val alpnProtocols: Option[Seq[String]] =
Expand All @@ -269,7 +270,7 @@ class H2ServerConfig extends ServerConfig with H2EndpointConfig {

@JsonIgnore
override def serverParams = withEndpointParams(super.serverParams
+ AddForwardedHeaderConfig.Param(addForwardedHeader))
+ AddForwardedHeaderConfig.Param(addForwardedHeader)).maybeWith(maxCallDepth.map(MaxCallDepthFilter.Param(_)))
}

abstract class H2IdentifierConfig extends PolymorphicConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.twitter.util.Future
import io.buoyant.linkerd.protocol.h2.ErrorReseter.H2ResponseException
import io.buoyant.router.RoutingFactory
import io.buoyant.router.RoutingFactory.ResponseException
import io.buoyant.router.http.MaxCallDepthFilter.MaxCallDepthExceeded

/**
* Coerces routing failures to the appropriate HTTP/2 error code
Expand All @@ -29,6 +30,8 @@ class ErrorReseter extends SimpleFilter[Request, Response] {
Future.value(LinkerdHeaders.Err.respond(e.exceptionMessage(), Status.BadGateway))
case e: RichConnectionFailedExceptionWithPath =>
Future.value(LinkerdHeaders.Err.respond(e.exceptionMessage, Status.BadGateway))
case e: MaxCallDepthExceeded =>
Future.value(LinkerdHeaders.Err.respond(e.getMessage, Status.BadRequest))
case H2ResponseException(rsp) =>
Future.value(rsp)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import com.twitter.finagle.http.{param => _, _}
import com.twitter.finagle.service.ExpiringService
import com.twitter.finagle.stats.{InMemoryStatsReceiver, NullStatsReceiver}
import com.twitter.finagle.tracing.{Annotation, BufferingTracer, NullTracer}
import com.twitter.io.{Buf, Pipe, Reader}
import com.twitter.io.{Buf, Pipe}
import com.twitter.util._
import io.buoyant.router.StackRouter.Client.PerClientParams
import io.buoyant.test.{Awaits, BudgetedRetries}
Expand Down Expand Up @@ -938,6 +938,36 @@ class HttpEndToEndTest
assert(resp.contentString.contains(responseDiscardedMsg))
}



test("returns 400 for requests that have more than allowed hops", Retryable) {
val yaml =
s"""|routers:
|- protocol: http
| servers:
| - port: 0
| maxCallDepth: 2
|""".stripMargin
val linker = Linker.load(yaml)
val router = linker.routers.head.initialize()
val s = router.servers.head.serve()

val req = Request()
req.headerMap.add(Fields.Via, "hop1, hop2, hop3")

val c = upstream(s)
try {
val resp = await(c(req))
resp.status must be (Status.BadRequest)
resp.contentString must be ("Maximum number of calls (2) has been exceeded. Please check for proxy loops.")

} finally {
await(c.close())
await(s.close())
}
}


def idleTimeMsBaseTest(config:String)(assertionsF: (Router.Initialized, InMemoryStatsReceiver, Int) => Unit): Unit = {
// Arrange
val stats = new InMemoryStatsReceiver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import com.twitter.finagle.buoyant.{ParamsMaybeWith, PathMatcher}
import com.twitter.finagle.client.{AddrMetadataExtraction, StackClient}
import com.twitter.finagle.filter.DtabStatsFilter
import com.twitter.finagle.http.filter.{ClientDtabContextFilter, ServerDtabContextFilter, StatsFilter}
import com.twitter.finagle.http.{Request, Response, param => hparam}
import com.twitter.finagle.http.{Fields, HeaderMap, Request, Response, param => hparam}
import com.twitter.finagle.liveness.FailureAccrualFactory
import com.twitter.finagle.service.Retries
import com.twitter.finagle.stack.nilStack
Expand All @@ -20,9 +20,10 @@ import com.twitter.finagle.{ServiceFactory, Stack}
import com.twitter.logging.Policy
import io.buoyant.linkerd.protocol.HttpRequestAuthorizerConfig.param
import io.buoyant.linkerd.protocol.http._
import io.buoyant.router.http.{AddForwardedHeader, ForwardClientCertFilter, TimestampHeaderFilter}
import io.buoyant.router.{ClassifiedRetries, Http, RoutingFactory}
import scala.collection.JavaConverters._
import io.buoyant.router.http._
import io.buoyant.router.HttpInstances._

class HttpInitializer extends ProtocolInitializer.Simple {
val name = "http"
Expand Down Expand Up @@ -81,6 +82,7 @@ class HttpInitializer extends ProtocolInitializer.Simple {
// ensure the server-stack framing filter is placed below the stats filter
// so that any malframed requests it fails are counted as errors
.insertAfter(StatsFilter.role, FramingFilter.serverModule)
.insertAfter(FramingFilter.role, MaxCallDepthFilter.module[Request, HeaderMap, Response](Fields.Via))
.insertBefore(AddForwardedHeader.module.role, AddForwardedHeaderConfig.module[Request, Response])
.remove(ServerDtabContextFilter.role)

Expand Down Expand Up @@ -179,12 +181,13 @@ trait HttpSvcConfig extends SvcConfig {

case class HttpServerConfig(
addForwardedHeader: Option[AddForwardedHeaderConfig],
timestampHeader: Option[String]
timestampHeader: Option[String],
maxCallDepth: Option[Int]
) extends ServerConfig {

@JsonIgnore
override def serverParams = {
super.serverParams +
super.serverParams.maybeWith(maxCallDepth.map(x => MaxCallDepthFilter.Param(x))) +
AddForwardedHeaderConfig.Param(addForwardedHeader) +
TimestampHeaderFilter.Param(timestampHeader)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.twitter.finagle.service.RetryPolicy.RetryableWriteException
import com.twitter.logging.{Level, Logger}
import io.buoyant.router.RoutingFactory
import io.buoyant.router.RoutingFactory.ResponseException
import io.buoyant.router.http.MaxCallDepthFilter
import scala.util.control.NonFatal

class ErrorResponder(maxHeaderSize: Int)
Expand All @@ -29,6 +30,8 @@ class ErrorResponder(maxHeaderSize: Int)
Headers.Err.respond(e.exceptionMessage(), Status.BadGateway, maxHeaderSize)
case e: RichConnectionFailedExceptionWithPath =>
Headers.Err.respond(e.exceptionMessage, Status.BadGateway, maxHeaderSize)
case e: MaxCallDepthFilter.MaxCallDepthExceeded =>
Headers.Err.respond(e.getMessage, Status.BadRequest, maxHeaderSize)
case _ =>
val message = e.getMessage match {
case null => e.getClass.getName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class HttpInitializerTest extends FunSuite with Awaits with Eventually {
.configured(maxHeaderSize).configured(maxInitLineSize)
.configured(maxReqSize).configured(maxRspSize)
.configured(streaming).configured(compression)
.serving(HttpServerConfig(None, None).mk(HttpInitializer, "yolo"))
.serving(HttpServerConfig(None, None, None).mk(HttpInitializer, "yolo"))
.initialize()
assert(router.servers.size == 1)
val sparams = router.servers.head.params
Expand Down
1 change: 1 addition & 0 deletions project/LinkerdBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ object LinkerdBuild extends Base {

val baseHttp = projectDir("router/base-http")
.dependsOn(core)
.withTests()

val h2 = projectDir("router/h2")
.dependsOn(baseHttp, Finagle.h2 % "compile->compile;test->test")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.buoyant.router.http

import com.twitter.finagle._
import com.twitter.finagle.http.Fields.Via
import com.twitter.finagle.http.{Request, Response}
import com.twitter.util.Future
import io.buoyant.router.http.ForwardClientCertFilter.Enabled
import io.buoyant.router.http.MaxCallDepthFilter.MaxCallDepthExceeded
import scala.util.control.NoStackTrace

class MaxCallDepthFilter[Req, H: HeadersLike, Rep](maxCalls: Int, headerKey: String)
(implicit requestLike: RequestLike[Req, H]) extends SimpleFilter[Req, Rep] {

def numCalls(viaValue: String) = viaValue.split(",").length

def apply(req: Req, svc: Service[Req, Rep]): Future[Rep] = {
val headersLike = implicitly[HeadersLike[H]]

headersLike.get(requestLike.headers(req), headerKey) match {
case Some(v) if numCalls(v) > maxCalls => Future.exception(MaxCallDepthExceeded(maxCalls))
case _ => svc(req)
}
}
}

object MaxCallDepthFilter {

final case class MaxCallDepthExceeded(calls: Int)
extends Exception(s"Maximum number of calls ($calls) has been exceeded. Please check for proxy loops.")
with NoStackTrace

final case class Param(value: Int) extends AnyVal {
def mk(): (Param, Stack.Param[Param]) = (this, Param.param)
}

object Param {
implicit val param = Stack.Param(Param(1000))
}

def module[Req, H: HeadersLike, Rep](headerKey: String)
(implicit requestLike: RequestLike[Req, H]): Stackable[ServiceFactory[Req, Rep]] =
new Stack.Module1[Param, ServiceFactory[Req, Rep]] {
val role = Stack.Role("MaxCallDepthFilter")
val description = "Limits the number of hops by looking at the Via header of a request"

def make(param: Param, next: ServiceFactory[Req, Rep]) =
new MaxCallDepthFilter(param.value, headerKey) andThen next

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.buoyant.router.http

import com.twitter.finagle.Service
import com.twitter.finagle.http.Fields.Via
import com.twitter.finagle.http.{HeaderMap, Request, Response, Status}
import com.twitter.util.Future
import io.buoyant.test.{Awaits, FunSuite}


class MaxCallDepthFilterTest extends FunSuite with Awaits {

implicit object HttpHeadersLike extends HeadersLike[HeaderMap] {
override def toSeq(headers: HeaderMap): Seq[(String, String)] = ???
override def contains(headers: HeaderMap, k: String): Boolean = ???
override def get(headers: HeaderMap, k: String): Option[String] = headers.get(k)
override def getAll(headers: HeaderMap, k: String): Seq[String] = ???
override def add(headers: HeaderMap, k: String, v: String): Unit = ???
override def set(headers: HeaderMap, k: String, v: String): Unit = ???
override def remove(headers: HeaderMap, key: String): Seq[String] = ???
}

implicit object HttpRequestLike extends RequestLike[Request, HeaderMap] {
override def headers(request: Request): HeaderMap = request.headerMap
}


def service(maxCallDepth: Int) = new MaxCallDepthFilter[Request, HeaderMap, Response](
maxCallDepth,
Via
).andThen(Service.mk[Request, Response](_ => Future.value(Response())))

test("passes through requests not exceeding max hops") {
val viaHeader = (1 to 10).map(v => s"hop $v").mkString(", ")
val req = Request()
req.headerMap.add(Via, viaHeader)
assert(await(service(10)(req)).status == Status.Ok)
}

test("stops requests exceeding max hops") {
val expectedMessage = "Maximum number of calls (9) has been exceeded. Please check for proxy loops."
val viaHeader = (1 to 10).map(v => s"hop $v").mkString(", ")
val req = Request()
req.headerMap.add(Via, viaHeader)

val exception = intercept[MaxCallDepthFilter.MaxCallDepthExceeded] {
await(service(9)(req))
}
assert(exception.getMessage == expectedMessage)
}


}
3 changes: 2 additions & 1 deletion router/h2/src/main/scala/io/buoyant/router/H2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import com.twitter.finagle.service.StatsFilter
import io.buoyant.router.context.ResponseClassifierCtx
import io.buoyant.router.context.h2.H2ClassifierCtx
import io.buoyant.router.h2.{ClassifiedRetries => H2ClassifiedRetries, _}
import io.buoyant.router.http.ForwardClientCertFilter
import io.buoyant.router.http.{ForwardClientCertFilter, MaxCallDepthFilter}
import io.buoyant.router.H2Instances._

object H2 extends Router[Request, Response]
Expand Down Expand Up @@ -91,6 +91,7 @@ object H2 extends Router[Request, Response]
val newStack: Stack[ServiceFactory[Request, Response]] = FinagleH2.Server.newStack
.insertAfter(StackServer.Role.protoTracing, h2.ProxyRewriteFilter.module)
.insertAfter(StackServer.Role.protoTracing, H2AddForwardedHeader.module)
.insertAfter(StackServer.Role.protoTracing, MaxCallDepthFilter.module[Request, Headers, Response](Headers.Via))
.replace(StatsFilter.role, StreamStatsFilter.module)

private val serverResponseClassifier =
Expand Down

0 comments on commit d87f193

Please sign in to comment.