package csbase.server.services.restservice.websocket;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.glassfish.grizzly.http.HttpRequestPacket;
import org.glassfish.grizzly.websockets.Broadcaster;
import org.glassfish.grizzly.websockets.DataFrame;
import org.glassfish.grizzly.websockets.Extension;
import org.glassfish.grizzly.websockets.HandShake;
import org.glassfish.grizzly.websockets.HandshakeException;
import org.glassfish.grizzly.websockets.OptimizedBroadcaster;
import org.glassfish.grizzly.websockets.ProtocolHandler;
import org.glassfish.grizzly.websockets.WebSocket;
import org.glassfish.grizzly.websockets.WebSocketApplication;
import org.glassfish.grizzly.websockets.WebSocketListener;

import csbase.server.Server;
import csbase.server.services.restservice.websocket.utils.WebSocketUtils;
import ibase.rest.api.authentication.v1.adapter.UnauthorizedException;

/**
 * Customized {@link WebSocketApplication} implementation that keeps track of
 * the open {@link CSBaseWebSocket} connections of each user (keyed by login)
 * and pings every tracked connection once a minute.
 * <p>
 * The connection registry is read by the ping scheduler thread while it is
 * updated from {@link #createSocket} and {@link #onClose}, so the per-user
 * lists stored in it are {@link CopyOnWriteArrayList} instances.
 */
public class CSBaseWebSocketApplication extends WebSocketApplication {

	// Logged in users connections, keyed by user login. The lists created by
	// this class are CopyOnWriteArrayList so the ping task can iterate them
	// safely while sockets are being added or removed.
	protected ConcurrentHashMap<String, List<CSBaseWebSocket>> connections = new ConcurrentHashMap<>();

	// Initialize optimized broadcaster
	protected final Broadcaster broadcaster = new OptimizedBroadcaster();

	/**
	 * Creates the application, initializes the Grizzly logger and schedules the
	 * periodic ping of all tracked connections.
	 */
	public CSBaseWebSocketApplication() {
		super();

		Server.logInfoMessage("Instantiating WebSocketApplication: " + this.getClass().getSimpleName());

		// Local environment log settings
		WebSocketUtils.InitializeGrizzlyLogger();

		// Ping connections every minute. The executor is never shut down, so
		// its thread is a daemon in order not to keep the JVM alive on its own.
		Executors.newScheduledThreadPool(1, runnable -> {
			Thread pingThread = new Thread(runnable, "csbase-websocket-ping");
			pingThread.setDaemon(true);
			return pingThread;
		}).scheduleAtFixedRate(this::pingAllConnections, 0, 60, TimeUnit.SECONDS);
	}

	/**
	 * Sends an empty ping frame to every tracked connection.
	 * <p>
	 * Failures are caught per socket: an exception escaping a task scheduled
	 * with {@code scheduleAtFixedRate} suppresses all subsequent executions,
	 * which would silently stop the pings for every user.
	 */
	private void pingAllConnections() {
		for (List<CSBaseWebSocket> userWebSockets : connections.values()) {
			for (CSBaseWebSocket ws : userWebSockets) {
				try {
					ws.sendPing(new byte[] {});
				} catch (RuntimeException e) {
					Server.logSevereMessage(
							this.getClass().getSimpleName() + " failed to ping a WebSocket connection", e);
				}
			}
		}
	}

	/**
	 * Creates a {@link CSBaseWebSocket} for the request and registers it under
	 * the login of its user.
	 *
	 * @param handler       protocol handler of the connection
	 * @param requestPacket HTTP request that initiated the WebSocket
	 * @param listeners     listeners to attach to the new socket
	 * @return the registered {@link CSBaseWebSocket}, or the superclass' default
	 *         socket when an {@link UnauthorizedException} is raised; such a
	 *         socket is closed in {@link #onConnect(WebSocket)}
	 */
	@Override
	public WebSocket createSocket(ProtocolHandler handler, HttpRequestPacket requestPacket,
			WebSocketListener... listeners) {

		try {

			// Instantiate new connection
			CSBaseWebSocket ws = new CSBaseWebSocket(handler, requestPacket, listeners);
			String login = ws.getUser().getLogin();

			// Add new connection to connections map. compute() keeps the
			// "create list if missing + add" step atomic with respect to the
			// removal of emptied lists done in onClose().
			connections.compute(login, (user, userWebSockets) -> {
				List<CSBaseWebSocket> updated = userWebSockets != null ? userWebSockets
						: new CopyOnWriteArrayList<>();
				updated.add(ws);
				return updated;
			});

			Server.logInfoMessage("WebSocket connection with " + this.getClass().getSimpleName()
					+ " requested by user " + login + " successfully created");

			return ws;

		} catch (UnauthorizedException e) {
			// The logger already receives the exception; no printStackTrace().
			Server.logSevereMessage("Authentication error with  " + this.getClass().getSimpleName(), e);
			return super.createSocket(handler, requestPacket, listeners);
		}
	}

	/**
	 * Delegates to the superclass and then closes, with code 401, any socket
	 * that is not a {@link CSBaseWebSocket} (i.e. one produced by the
	 * unauthorized fallback of {@link #createSocket}).
	 *
	 * @param socket the socket that has just connected
	 */
	@Override
	public void onConnect(WebSocket socket) {
		super.onConnect(socket);
		if (!(socket instanceof CSBaseWebSocket)) {
			Server.logSevereMessage("Closing unauthorized WebSocket " + this.getClass().getSimpleName());
			socket.close(401, "Unauthorized");
			return;
		}

		Server.logInfoMessage(
				this.getClass().getSimpleName() + " onConnect called for user " + ((CSBaseWebSocket) socket).getUser());

	}

	/** Delegates to the superclass implementation. */
	@Override
	public void onMessage(WebSocket socket, String text) {
		super.onMessage(socket, text);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public void onMessage(WebSocket socket, byte[] bytes) {
		super.onMessage(socket, bytes);
	}

	/**
	 * Delegates to the superclass and, for a {@link CSBaseWebSocket}, removes it
	 * from the connections map; the user's entry is dropped once its last
	 * connection is gone.
	 *
	 * @param socket the socket being closed
	 * @param frame  the closing frame
	 */
	@Override
	public void onClose(WebSocket socket, DataFrame frame) {
		super.onClose(socket, frame);

		if (socket instanceof CSBaseWebSocket) {
			CSBaseWebSocket ws = (CSBaseWebSocket) socket;

			Server.logInfoMessage(this.getClass().getSimpleName() + " onClose called for user " + ws.getUser());

			// Remove socket from connections map. computeIfPresent() skips the
			// update when the user has no entry (e.g. a repeated close) instead
			// of failing on a null list.
			List<CSBaseWebSocket> remaining = connections.computeIfPresent(ws.getUser().getLogin(),
					(user, userWebSockets) -> {
						userWebSockets.removeIf(userWebSocket -> userWebSocket.getId().equals(ws.getId()));
						return userWebSockets.isEmpty() ? null : userWebSockets;
					});

			int openConnections = remaining == null ? 0 : remaining.size();
			Server.logInfoMessage(this.getClass().getSimpleName() + " user " + ws.getUser() + " has now "
					+ openConnections + " connections");
		}
	}

	/** Delegates to the superclass implementation. */
	@Override
	public void onPing(WebSocket socket, byte[] bytes) {
		super.onPing(socket, bytes);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public void onPong(WebSocket socket, byte[] bytes) {
		super.onPong(socket, bytes);
	}

	/**
	 * Logs the error and delegates the decision to the superclass.
	 *
	 * @param webSocket the socket on which the error happened
	 * @param t         the error
	 * @return whatever the superclass implementation returns
	 */
	@Override
	protected boolean onError(WebSocket webSocket, Throwable t) {
		Server.logSevereMessage(this.getClass().getSimpleName() + " onError callback", t);
		return super.onError(webSocket, t);

	}

	/** Delegates to the superclass implementation. */
	@Override
	protected boolean add(WebSocket socket) {
		return super.add(socket);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public boolean remove(WebSocket socket) {
		return super.remove(socket);
	}

	/** Delegates to the superclass implementation. */
	@Override
	protected Set<WebSocket> getWebSockets() {
		return super.getWebSockets();
	}

	/** Delegates to the superclass implementation. */
	@Override
	protected void handshake(HandShake handshake) throws HandshakeException {
		super.handshake(handshake);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public List<String> getSupportedProtocols(List<String> subProtocol) {
		return super.getSupportedProtocols(subProtocol);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public List<Extension> getSupportedExtensions() {
		return super.getSupportedExtensions();
	}

	/** Delegates to the superclass implementation. */
	@Override
	public void onExtensionNegotiation(List<Extension> extensions) {
		super.onExtensionNegotiation(extensions);
	}
}
