Skip to content

Commit

Permalink
DefaultWebSessionManager decoupled from DefaultWebSession
Browse files Browse the repository at this point in the history
DefaultWebSessionManager no longer requires the WebSessionStore
to use DefaultWebSession.

Removed explicit start() in save(). This seemed unnecessary since at
that point isStarted is guaranteed to return true. The status can
be updated through the copy constructor.

DefaultWebSessionTests added.

Issue: SPR-15875
  • Loading branch information
rwinch authored and rstoyanchev committed Sep 6, 2017
1 parent 8691247 commit 8ad14ae
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,9 @@ class DefaultWebSession implements WebSession {
}

/**
* Constructor to refresh an existing session for a new request.
* @param existingSession the session to recreate
* @param lastAccessTime the last access time
* @param saveOperation save operation for the current request
*/
DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime,
Function<WebSession, Mono<Void>> saveOperation) {

this.id = existingSession.id;
this.idGenerator = existingSession.idGenerator;
this.attributes = existingSession.attributes;
this.clock = existingSession.clock;
this.changeIdOperation = existingSession.changeIdOperation;
this.saveOperation = saveOperation;
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = existingSession.maxIdleTime;
this.state = existingSession.state;
}

/**
* For testing purposes.
* Constructor for creating a new session with an updated last access time.
* @param existingSession the existing session to copy
* @param lastAccessTime the new last access time
*/
DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime) {
this.id = existingSession.id;
Expand All @@ -119,7 +100,7 @@ class DefaultWebSession implements WebSession {
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = existingSession.maxIdleTime;
this.state = existingSession.state;
this.state = existingSession.isStarted() ? State.STARTED : State.NEW;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
*/
package org.springframework.web.server.session;

import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.util.List;

import reactor.core.publisher.Flux;
Expand All @@ -43,9 +40,6 @@ public class DefaultWebSessionManager implements WebSessionManager {

private WebSessionStore sessionStore = new InMemoryWebSessionStore();

private Clock clock = Clock.system(ZoneId.of("GMT"));


/**
* Configure the id resolution strategy.
* <p>By default an instance of {@link CookieWebSessionIdResolver}.
Expand Down Expand Up @@ -80,38 +74,14 @@ public WebSessionStore getSessionStore() {
return this.sessionStore;
}

/**
* Configure the {@link Clock} to use to set lastAccessTime on every created
* session and to calculate if it is expired.
* <p>This may be useful to align to different timezone or to set the clock
* back in a test, e.g. {@code Clock.offset(clock, Duration.ofMinutes(-31))}
* in order to simulate session expiration.
* <p>By default this is {@code Clock.system(ZoneId.of("GMT"))}.
* @param clock the clock to use
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "'clock' is required.");
this.clock = clock;
}

/**
* Return the configured clock for session lastAccessTime calculations.
*/
public Clock getClock() {
return this.clock;
}


@Override
public Mono<WebSession> getSession(ServerWebExchange exchange) {
return Mono.defer(() ->
retrieveSession(exchange)
.flatMap(session -> removeSessionIfExpired(exchange, session))
.flatMap(this.getSessionStore()::updateLastAccessTime)
.switchIfEmpty(createSession(exchange))
.cast(DefaultWebSession.class)
.map(session -> new DefaultWebSession(session, session.getLastAccessTime(), s -> saveSession(exchange, s)))
.doOnNext(session -> exchange.getResponse().beforeCommit(session::save)));
.switchIfEmpty(this.sessionStore.createWebSession())
.doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session))));
}

private Mono<WebSession> retrieveSession(ServerWebExchange exchange) {
Expand All @@ -128,35 +98,28 @@ private Mono<WebSession> removeSessionIfExpired(ServerWebExchange exchange, WebS
return Mono.just(session);
}

private Mono<Void> saveSession(ServerWebExchange exchange, WebSession session) {
private Mono<Void> save(ServerWebExchange exchange, WebSession session) {
if (session.isExpired()) {
return Mono.error(new IllegalStateException(
"Sessions are checked for expiration and have their " +
"lastAccessTime updated when first accessed during request processing. " +
"However this session is expired meaning that maxIdleTime elapsed " +
"before the call to session.save()."));
"lastAccessTime updated when first accessed during request processing. " +
"However this session is expired meaning that maxIdleTime elapsed " +
"before the call to session.save()."));
}

if (!session.isStarted()) {
return Mono.empty();
}

// Force explicit start
session.start();

if (hasNewSessionId(exchange, session)) {
this.sessionIdResolver.setSessionId(exchange, session.getId());
DefaultWebSessionManager.this.sessionIdResolver.setSessionId(exchange, session.getId());
}

return this.sessionStore.storeSession(session);
return session.save();
}

private boolean hasNewSessionId(ServerWebExchange exchange, WebSession session) {
List<String> ids = getSessionIdResolver().resolveSessionIds(exchange);
return ids.isEmpty() || !session.getId().equals(ids.get(0));
}

private Mono<WebSession> createSession(ServerWebExchange exchange) {
return this.sessionStore.createWebSession();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
*/
package org.springframework.web.server.session;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Arrays;
import java.util.Collections;

Expand All @@ -32,8 +28,6 @@
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.util.IdGenerator;
import org.springframework.util.JdkIdGenerator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
Expand All @@ -42,11 +36,11 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

/**
Expand All @@ -57,11 +51,6 @@
@RunWith(MockitoJUnitRunner.class)
public class DefaultWebSessionManagerTests {

private static final Clock CLOCK = Clock.system(ZoneId.of("GMT"));

private static final IdGenerator idGenerator = new JdkIdGenerator();


private DefaultWebSessionManager manager;

private ServerWebExchange exchange;
Expand All @@ -72,10 +61,23 @@ public class DefaultWebSessionManagerTests {
@Mock
private WebSessionStore store;

@Mock
private WebSession createSession;

@Mock
private WebSession retrieveSession;

@Mock
private WebSession updateSession;

@Before
public void setUp() throws Exception {
when(this.store.createWebSession()).thenReturn(Mono.just(createDefaultWebSession()));
when(this.store.updateLastAccessTime(any())).thenAnswer( invocation -> Mono.just(invocation.getArgument(0)));
when(this.store.createWebSession()).thenReturn(Mono.just(this.createSession));
when(this.store.updateLastAccessTime(any())).thenReturn(Mono.just(this.updateSession));
when(this.store.retrieveSession(any())).thenReturn(Mono.just(this.retrieveSession));
when(this.createSession.save()).thenReturn(Mono.empty());
when(this.updateSession.getId()).thenReturn("update-session-id");
when(this.retrieveSession.getId()).thenReturn("retrieve-session-id");

this.manager = new DefaultWebSessionManager();
this.manager.setSessionIdResolver(this.idResolver);
Expand All @@ -87,90 +89,71 @@ public void setUp() throws Exception {
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
}


@Test
public void getSessionWithoutStarting() throws Exception {
public void getSessionSaveWhenCreatedAndNotStartedThenNotSaved() throws Exception {
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList());
WebSession session = this.manager.getSession(this.exchange).block();
session.save().block();
this.exchange.getResponse().setComplete().block();

assertFalse(session.isStarted());
assertFalse(session.isExpired());
verify(this.store, never()).storeSession(any());
verifyZeroInteractions(this.retrieveSession, this.updateSession);
verify(this.createSession, never()).save();
verify(this.idResolver, never()).setSessionId(any(), any());
}

@Test
public void startSessionExplicitly() throws Exception {
public void getSessionSaveWhenCreatedAndStartedThenSavesAndSetsId() throws Exception {
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList());
when(this.store.storeSession(any())).thenReturn(Mono.empty());
WebSession session = this.manager.getSession(this.exchange).block();
session.start();
session.save().block();
when(this.createSession.isStarted()).thenReturn(true);
this.exchange.getResponse().setComplete().block();

String id = session.getId();
verify(this.store).createWebSession();
verify(this.store).storeSession(any());
verify(this.createSession).save();
verify(this.idResolver).setSessionId(any(), eq(id));
}

@Test
public void startSessionImplicitly() throws Exception {
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList());
when(this.store.storeSession(any())).thenReturn(Mono.empty());
WebSession session = this.manager.getSession(this.exchange).block();
session.getAttributes().put("foo", "bar");
session.save().block();

verify(this.store).createWebSession();
verify(this.idResolver).setSessionId(any(), any());
verify(this.store).storeSession(any());
}

@Test
public void exchangeWhenResponseSetCompleteThenSavesAndSetsId() throws Exception {
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList());
when(this.store.storeSession(any())).thenReturn(Mono.empty());
String id = this.createSession.getId();
WebSession session = this.manager.getSession(this.exchange).block();
String id = session.getId();
session.getAttributes().put("foo", "bar");
when(this.createSession.isStarted()).thenReturn(true);
this.exchange.getResponse().setComplete().block();

verify(this.idResolver).setSessionId(any(), eq(id));
verify(this.store).storeSession(any());
verify(this.createSession).save();
}

@Test
public void existingSession() throws Exception {
DefaultWebSession existing = createDefaultWebSession();
String id = existing.getId();
when(this.store.retrieveSession(id)).thenReturn(Mono.just(existing));
String id = this.updateSession.getId();
when(this.store.retrieveSession(id)).thenReturn(Mono.just(this.updateSession));
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id));

WebSession actual = this.manager.getSession(this.exchange).block();
assertNotNull(actual);
assertEquals(existing.getId(), actual.getId());
assertEquals(id, actual.getId());
}

@Test
public void existingSessionIsExpired() throws Exception {
DefaultWebSession existing = createDefaultWebSession();
existing.start();
Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31));
existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty());
when(this.store.retrieveSession(existing.getId())).thenReturn(Mono.just(existing));
when(this.store.removeSession(existing.getId())).thenReturn(Mono.empty());
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(existing.getId()));
String id = this.retrieveSession.getId();
when(this.retrieveSession.isExpired()).thenReturn(true);
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id));
when(this.store.removeSession(any())).thenReturn(Mono.empty());

WebSession actual = this.manager.getSession(this.exchange).block();
assertNotSame(existing, actual);
verify(this.store).removeSession(existing.getId());
assertEquals(this.createSession.getId(), actual.getId());
verify(this.store).removeSession(id);
verify(this.idResolver).expireSession(any());
}

@Test
public void multipleSessionIds() throws Exception {
DefaultWebSession existing = createDefaultWebSession();
WebSession existing = this.updateSession;
String id = existing.getId();
when(this.store.retrieveSession(any())).thenReturn(Mono.empty());
when(this.store.retrieveSession(id)).thenReturn(Mono.just(existing));
Expand All @@ -180,8 +163,4 @@ public void multipleSessionIds() throws Exception {
assertNotNull(actual);
assertEquals(existing.getId(), actual.getId());
}

private DefaultWebSession createDefaultWebSession() {
return new DefaultWebSession(idGenerator, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty());
}
}
Loading

0 comments on commit 8ad14ae

Please sign in to comment.