package br.pucrio.tecgraf.soma.logsmonitor.websocket;

import br.pucrio.tecgraf.soma.logsmonitor.utils.ConstantsUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import java.util.List;
import java.util.Map;

/**
 * WebSocket handshake interceptor that requires an access token in the handshake query string.
 *
 * <p>The token is read from the query parameter named by {@link
 * ConstantsUtils#WEBSOCKET_HANDSHAKE_QUERY_PARAM_TOKEN_KEY}. If a non-empty value is present, it is
 * copied into the WebSocket session attributes under {@link
 * ConstantsUtils#WEBSOCKET_SESSION_ATTRIBUTES_TOKEN_KEY} and the handshake proceeds. Otherwise the
 * handshake is rejected with HTTP 401.
 *
 * <p>This class only checks that a token is present. It does not verify the token's signature or
 * contents (presumably a JWT, per the class name); any such validation must happen elsewhere.
 */
public class WebSocketHandshakeAuthJwtInterceptor implements HandshakeInterceptor {

  private static final Log logger = LogFactory.getLog(WebSocketHandshakeAuthJwtInterceptor.class);

  /**
   * Rejects the handshake unless the request carries a non-empty token query parameter.
   *
   * @param request the incoming handshake request
   * @param response the handshake response; its status is set to 401 when the token is missing
   * @param wsHandler the target WebSocket handler (unused)
   * @param attributes the session attributes; receives the token on success
   * @return {@code true} to proceed with the handshake, {@code false} to abort it
   * @throws Exception declared by {@link HandshakeInterceptor}; not thrown by this implementation
   */
  @Override
  public boolean beforeHandshake(
      ServerHttpRequest request,
      ServerHttpResponse response,
      WebSocketHandler wsHandler,
      Map<String, Object> attributes)
      throws Exception {
    final String accessToken = getAuthToken(request);

    if (accessToken == null || accessToken.isEmpty()) {
      response.setStatusCode(HttpStatus.UNAUTHORIZED);
      // The format string previously lacked a conversion specifier, so the headers were dropped.
      logger.error(
          String.format("There was an unauthorized request: %s", request.getHeaders().toString()));
      return false;
    }
    attributes.put(ConstantsUtils.WEBSOCKET_SESSION_ATTRIBUTES_TOKEN_KEY, accessToken);
    return true;
  }

  /**
   * Extracts the first value of the token query parameter from the handshake URI.
   *
   * @param request the handshake request
   * @return the first token value, or {@code null} if the parameter is absent, has no values, or
   *     was given without a value (e.g. {@code ?token})
   */
  private String getAuthToken(ServerHttpRequest request) {
    final MultiValueMap<String, String> queryParams =
        UriComponentsBuilder.fromHttpRequest(request).build().getQueryParams();
    // A single lookup covers both the "absent key" and "present but no values" cases.
    final List<String> tokenParamValues =
        queryParams.get(ConstantsUtils.WEBSOCKET_HANDSHAKE_QUERY_PARAM_TOKEN_KEY);
    if (tokenParamValues == null || tokenParamValues.isEmpty()) {
      return null;
    }
    return tokenParamValues.get(0);
  }

  /** Intentionally a no-op: nothing needs to happen after the handshake completes. */
  @Override
  public void afterHandshake(
      ServerHttpRequest request,
      ServerHttpResponse response,
      WebSocketHandler wsHandler,
      Exception exception) {}
}
