/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.security.ssl;

import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLException;
import org.wildfly.security._private.ElytronMessages;
import org.wildfly.security.ssl.MechanismDatabase;
import org.wildfly.security.ssl.SSLConnectionInformation;

final class SSLExplorer {
    /** Mechanism database used to resolve cipher suites from their two identifier bytes. */
    private static final MechanismDatabase database = MechanismDatabase.getInstance();
    /** Number of leading bytes the public entry points require before they inspect a record. */
    public static final int RECORD_HEADER_SIZE = 5;

    /** Private constructor: the members of this class visible here are all static. */
    private SSLExplorer() {
    }

    /**
     * Reports how many bytes must be available before the record at the start of
     * {@code source} can be explored.
     *
     * <p>The buffer is read through a duplicate, so the caller's position and limit
     * are not modified.
     *
     * @param source buffer positioned at the first byte of the record
     * @return {@code 5} for an SSLv2-style hello (high bit set in the first byte and
     *         third byte equal to 1); otherwise the 16-bit big-endian length held in
     *         the fourth and fifth bytes plus {@code 5}
     * @throws BufferUnderflowException if fewer than five bytes remain
     */
    public static int getRequiredSize(ByteBuffer source) {
        final ByteBuffer view = source.duplicate();
        if (view.remaining() < 5) {
            throw new BufferUnderflowException();
        }
        final byte b0 = view.get();
        view.get(); // the second header byte plays no part in the size computation
        final byte b2 = view.get();
        final boolean v2Hello = (b0 & 0x80) != 0 && b2 == 1;
        if (v2Hello) {
            // Only the header itself is required for this record flavour.
            return 5;
        }
        final int high = view.get() & 0xFF;
        final int low = view.get() & 0xFF;
        return (high << 8 | low) + 5;
    }

    /**
     * Array-based variant of {@link #getRequiredSize(ByteBuffer)}.
     *
     * @param source array holding the record bytes
     * @param offset index of the first byte to examine
     * @param length number of bytes available starting at {@code offset}
     * @return the value computed by {@link #getRequiredSize(ByteBuffer)} for that range
     * @throws IOException declared by the signature; nothing in this method body throws it
     */
    public static int getRequiredSize(byte[] source, int offset, int length) throws IOException {
        // A read-only view of the requested range is all the buffer variant needs.
        return SSLExplorer.getRequiredSize(ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer());
    }

    /**
     * Inspects the first record in {@code source} and extracts connection information
     * from the client hello it carries.
     *
     * <p>Reading happens on a duplicate, so the caller's position and limit stay as they were.
     *
     * @param source buffer positioned at the first byte of the record
     * @return the information gathered from the hello message
     * @throws BufferUnderflowException if fewer than five bytes remain
     * @throws SSLException if the record is neither an SSLv2-style hello nor a record
     *         whose first byte is 22, or if a delegate rejects its content
     */
    public static SSLConnectionInformationImpl explore(ByteBuffer source) throws SSLException {
        final ByteBuffer view = source.duplicate();
        if (view.remaining() < 5) {
            throw new BufferUnderflowException();
        }
        final byte first = view.get();
        final byte second = view.get();
        final byte third = view.get();

        // High bit of the first byte plus a third byte of 1 selects the V2 hello path.
        final boolean highBitSet = (first & 0x80) != 0;
        if (highBitSet && third == 1) {
            return SSLExplorer.exploreV2HelloRecord(view, first, second, third);
        }
        if (first != 22) {
            throw ElytronMessages.log.notHandshakeRecord();
        }
        return SSLExplorer.exploreTLSRecord(view, first, second, third);
    }

    /**
     * Array-based variant of {@link #explore(ByteBuffer)}.
     *
     * @param source array holding the record bytes
     * @param offset index of the first byte to examine
     * @param length number of bytes available starting at {@code offset}
     * @return the information gathered by {@link #explore(ByteBuffer)} for that range
     * @throws IOException if the buffer variant rejects the record (it throws {@code SSLException})
     */
    public static SSLConnectionInformationImpl explore(byte[] source, int offset, int length) throws IOException {
        // Expose just the requested range, read-only, to the buffer variant.
        return SSLExplorer.explore(ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer());
    }

    /**
     * Parses the body of an SSLv2-format client hello.
     *
     * <p>After the two version bytes the message carries three 16-bit lengths in this
     * order: cipher-spec length, session-id length, challenge length (RFC 5246,
     * Appendix E.2). Only the cipher-spec length is needed; the other two are skipped.
     *
     * @param input buffer positioned just after the third header byte
     * @param firstByte first header byte (not used here)
     * @param secondByte second header byte (not used here)
     * @param thirdByte message type; must be 1
     * @return connection information with record version 0.2, the hello version read
     *         from the message, no SNI names, no ALPN protocols, and the recognised cipher suites
     * @throws SSLException if the message type is not 1 or the message is truncated
     */
    private static SSLConnectionInformationImpl exploreV2HelloRecord(ByteBuffer input, byte firstByte, byte secondByte, byte thirdByte) throws SSLException {
        try {
            if (thirdByte != 1) {
                throw ElytronMessages.log.unsupportedSslRecord();
            }
            byte helloVersionMajor = input.get();
            byte helloVersionMinor = input.get();
            // The cipher-spec length comes FIRST; reading it after the two skips would
            // pick up the challenge length instead.
            int csLen = SSLExplorer.getInt16(input);
            input.getShort(); // session-id length, unused
            input.getShort(); // challenge length, unused
            List<String> ciphers = new ArrayList<String>();
            // Each V2 cipher spec is three bytes; a zero lead byte marks a two-byte suite id.
            for (; csLen >= 3; csLen -= 3) {
                int lead = SSLExplorer.getInt8(input);
                int byte1 = SSLExplorer.getInt8(input);
                int byte2 = SSLExplorer.getInt8(input);
                if (lead != 0) {
                    continue;
                }
                MechanismDatabase.Entry entry = database.getCipherSuiteById(byte1, byte2);
                if (entry != null) {
                    ciphers.add(entry.getName());
                }
            }
            // Explicit byte casts: int literals are not implicitly narrowed for method arguments.
            return new SSLConnectionInformationImpl((byte) 0, (byte) 2, helloVersionMajor, helloVersionMinor, Collections.<SNIServerName>emptyList(), Collections.<String>emptyList(), ciphers.isEmpty() ? Collections.<String>emptyList() : ciphers);
        }
        catch (BufferUnderflowException ignored) {
            // Running out of bytes mid-message means the record is malformed.
            throw ElytronMessages.log.invalidHandshakeRecord();
        }
    }

    /**
     * Parses a record whose first byte is 22 and hands its payload to the handshake parser.
     *
     * @param input buffer positioned just after the third header byte
     * @param firstByte record type; must be 22
     * @param secondByte record major version, passed through to the handshake parser
     * @param thirdByte record minor version, passed through to the handshake parser
     * @return the information extracted from the handshake message
     * @throws BufferUnderflowException if the declared record length exceeds the bytes remaining
     * @throws SSLException if the record type is wrong or the handshake content is truncated or rejected
     */
    private static SSLConnectionInformationImpl exploreTLSRecord(ByteBuffer input, byte firstByte, byte secondByte, byte thirdByte) throws SSLException {
        if (firstByte != 22) {
            throw ElytronMessages.log.notHandshakeRecord();
        }
        final int declaredLength = SSLExplorer.getInt16(input);
        final int available = input.remaining();
        if (available < declaredLength) {
            // Not converted to SSLException: the whole record is simply not in the buffer yet.
            throw new BufferUnderflowException();
        }
        try {
            return SSLExplorer.exploreHandshake(input, secondByte, thirdByte, declaredLength);
        } catch (BufferUnderflowException ignored) {
            // The record was fully present, so an underflow here means bad content.
            throw ElytronMessages.log.invalidHandshakeRecord();
        }
    }

    /**
     * Validates the handshake header and delegates the message body to the client hello parser.
     *
     * @param input buffer positioned at the handshake type byte
     * @param recordMajorVersion major version taken from the record header
     * @param recordMinorVersion minor version taken from the record header
     * @param recordLength payload length declared by the record header
     * @return the information extracted from the client hello
     * @throws SSLException if the handshake type is not 1, or the 24-bit handshake length
     *         is larger than {@code recordLength - 4}
     */
    private static SSLConnectionInformationImpl exploreHandshake(ByteBuffer input, byte recordMajorVersion, byte recordMinorVersion, int recordLength) throws SSLException {
        if (input.get() != 1) {
            throw ElytronMessages.log.expectedClientHello();
        }
        final int bodyLength = SSLExplorer.getInt24(input);
        // Four bytes of the record (type + 24-bit length) have been consumed by this header.
        final int maxBodyLength = recordLength - 4;
        if (bodyLength > maxBodyLength) {
            throw ElytronMessages.log.multiRecordSSLHandshake();
        }
        // Restrict a duplicate to exactly the handshake body so parsing cannot run past it.
        final ByteBuffer body = input.duplicate();
        body.limit(bodyLength + body.position());
        return SSLExplorer.exploreClientHello(body, recordMajorVersion, recordMinorVersion);
    }

    /**
     * Parses a client hello body: hello version, a fixed 32-byte field, an 8-bit-length
     * vector, the cipher suite list, a second 8-bit-length vector and, if any bytes
     * remain, the extensions.
     *
     * <p>NOTE(review): per RFC 5246 section 7.4.1.2 the skipped fields are the random,
     * the session id and the compression methods.
     *
     * @param input buffer positioned at the hello version and limited to the handshake body
     * @param recordMajorVersion major version taken from the record header
     * @param recordMinorVersion minor version taken from the record header
     * @return connection information with versions, SNI names, ALPN protocols and the
     *         names of the cipher suites known to the mechanism database
     * @throws BufferUnderflowException if the body ends before the fixed 32-byte field is complete
     * @throws SSLException if an extension is rejected by the extension parsers
     */
    private static SSLConnectionInformationImpl exploreClientHello(ByteBuffer input, byte recordMajorVersion, byte recordMinorVersion) throws SSLException {
        byte helloMajorVersion = input.get();
        byte helloMinorVersion = input.get();
        // Check before moving: Buffer.position(int) throws IllegalArgumentException when the
        // target lies beyond the limit, which would bypass the callers' underflow handling.
        if (input.remaining() < 32) {
            throw new BufferUnderflowException();
        }
        input.position(input.position() + 32);
        SSLExplorer.ignoreByteVector8(input);
        List<String> ciphers = new ArrayList<String>();
        // Every cipher suite id is two bytes; suites unknown to the database are dropped.
        for (int csLen = SSLExplorer.getInt16(input); csLen > 0; csLen -= 2) {
            int byte1 = SSLExplorer.getInt8(input);
            int byte2 = SSLExplorer.getInt8(input);
            MechanismDatabase.Entry entry = database.getCipherSuiteById(byte1, byte2);
            if (entry != null) {
                ciphers.add(entry.getName());
            }
        }
        SSLExplorer.ignoreByteVector8(input);
        // Extensions are optional; anything left in the body is treated as the extension block.
        ExtensionInfo info = input.remaining() > 0 ? SSLExplorer.exploreExtensions(input) : null;
        List<SNIServerName> snList = info != null ? info.sni : Collections.<SNIServerName>emptyList();
        List<String> alpnProtocols = info != null ? info.alpn : Collections.<String>emptyList();
        return new SSLConnectionInformationImpl(recordMajorVersion, recordMinorVersion, helloMajorVersion, helloMinorVersion, snList, alpnProtocols, ciphers.isEmpty() ? Collections.<String>emptyList() : ciphers);
    }

    /**
     * Walks the extension block and collects the contents of extension types 0 and 16;
     * every other extension is skipped.
     *
     * @param input buffer positioned at the 16-bit total length of the extension block
     * @return the SNI names (type 0) and ALPN protocols (type 16) found; each list is
     *         empty when its extension is absent
     * @throws SSLException if one of the two parsed extensions is malformed
     */
    private static ExtensionInfo exploreExtensions(ByteBuffer input) throws SSLException {
        List<SNIServerName> serverNames = Collections.emptyList();
        List<String> protocols = Collections.emptyList();
        int remaining = SSLExplorer.getInt16(input);
        while (remaining > 0) {
            final int type = SSLExplorer.getInt16(input);
            final int size = SSLExplorer.getInt16(input);
            switch (type) {
                case 0:
                    serverNames = SSLExplorer.exploreSNIExt(input, size);
                    break;
                case 16:
                    protocols = SSLExplorer.exploreALPN(input, size);
                    break;
                default:
                    SSLExplorer.ignoreByteVector(input, size);
                    break;
            }
            // Each extension accounts for its body plus the 4-byte type/length header.
            remaining -= size + 4;
        }
        return new ExtensionInfo(serverNames, protocols);
    }

    /**
     * Parses the body of an extension of type 16: a 16-bit list length followed by
     * entries that each consist of a one-byte length and that many bytes of UTF-8 text.
     *
     * @param input buffer positioned at the first byte of the extension body
     * @param extLen length of the extension body as declared in the extension header
     * @return the decoded protocol names in wire order; an empty list when the body is
     *         shorter than two bytes
     * @throws SSLException if the list length is zero or disagrees with {@code extLen},
     *         or an entry claims more bytes than are left in the list
     * @throws BufferUnderflowException if the buffer ends before the declared data
     */
    private static List<String> exploreALPN(ByteBuffer input, int extLen) throws SSLException {
        if (extLen < 2) {
            // Too short to carry a list length. Consume the stray byte, if any, so the
            // caller's cursor still ends up at the start of the next extension.
            if (extLen == 1) {
                input.get();
            }
            return Collections.<String>emptyList();
        }
        int listLen = SSLExplorer.getInt16(input);
        if (listLen == 0 || listLen + 2 != extLen) {
            throw ElytronMessages.log.invalidTlsExt();
        }
        List<String> protocols = new ArrayList<String>();
        int rem = listLen;
        while (rem > 0) {
            // The entry length is an unsigned byte; mask so 128..255 never turns negative.
            int len = SSLExplorer.getInt8(input) & 0xFF;
            // The entry occupies its length byte plus len bytes; both must fit in the list.
            if (len + 1 > rem) {
                throw ElytronMessages.log.notEnoughData();
            }
            byte[] b = new byte[len];
            input.get(b);
            protocols.add(new String(b, StandardCharsets.UTF_8));
            rem -= len + 1;
        }
        return protocols;
    }

    /**
     * Parses the body of an extension of type 0: a 16-bit list length followed by
     * entries that each consist of a one-byte name type, a 16-bit name length and the
     * encoded name.
     *
     * @param input buffer positioned at the first byte of the extension body
     * @param extLen length of the extension body as declared in the extension header
     * @return an unmodifiable list of the server names in wire order; name type 0
     *         becomes an {@link SNIHostName}, every other type an {@code UnknownServerName}
     * @throws SSLException if the body is empty, the lengths are inconsistent, a host
     *         name is empty or illegal, or a name type occurs more than once
     * @throws BufferUnderflowException if the buffer ends before the declared data
     */
    private static List<SNIServerName> exploreSNIExt(ByteBuffer input, int extLen) throws SSLException {
        LinkedHashMap<Integer, SNIServerName> sniMap = new LinkedHashMap<Integer, SNIServerName>();
        int remains = extLen;
        if (extLen >= 2) {
            int listLen = SSLExplorer.getInt16(input);
            if (listLen == 0 || listLen + 2 != extLen) {
                throw ElytronMessages.log.invalidTlsExt();
            }
            remains -= 2;
            while (remains > 0) {
                SNIServerName serverName;
                // The name type is an unsigned byte; SNIServerName rejects negative types.
                int code = SSLExplorer.getInt8(input) & 0xFF;
                int snLen = SSLExplorer.getInt16(input);
                if (snLen > remains) {
                    throw ElytronMessages.log.notEnoughData();
                }
                byte[] encoded = new byte[snLen];
                input.get(encoded);
                if (code == 0) {
                    if (encoded.length == 0) {
                        throw ElytronMessages.log.emptyHostNameSni();
                    }
                    try {
                        serverName = new SNIHostName(encoded);
                    } catch (IllegalArgumentException e) {
                        // The bytes come from the peer: report an illegal host name as a
                        // protocol error rather than leaking an unchecked exception.
                        throw (SSLException) ElytronMessages.log.invalidTlsExt().initCause(e);
                    }
                } else {
                    serverName = new UnknownServerName(code, encoded);
                }
                if (sniMap.put(serverName.getType(), serverName) != null) {
                    throw ElytronMessages.log.duplicatedSniServerName(serverName.getType());
                }
                // One type byte + two length bytes + the name itself.
                remains -= encoded.length + 3;
            }
        } else if (extLen == 0) {
            throw ElytronMessages.log.invalidTlsExt();
        }
        // Anything but an exact fit (including a one-byte body) is malformed.
        if (remains != 0) {
            throw ElytronMessages.log.invalidTlsExt();
        }
        return Collections.unmodifiableList(new ArrayList<SNIServerName>(sniMap.values()));
    }

    /**
     * Reads one byte and returns it as an unsigned value.
     *
     * <p>Callers use the result as a vector length, a cipher-suite id byte or a name
     * type, so the byte must not be sign-extended: 0x80..0xFF has to come back as
     * 128..255 rather than as a negative number.
     *
     * @param input buffer to read from; its position advances by one
     * @return the byte value in the range 0..255
     * @throws java.nio.BufferUnderflowException if no byte remains
     */
    private static int getInt8(ByteBuffer input) {
        return input.get() & 0xFF;
    }

    /**
     * Reads two bytes as an unsigned big-endian 16-bit integer.
     *
     * @param input buffer to read from; its position advances by two
     * @return the value in the range 0..65535
     * @throws java.nio.BufferUnderflowException if fewer than two bytes remain
     */
    private static int getInt16(ByteBuffer input) {
        final int high = input.get() & 0xFF;
        final int low = input.get() & 0xFF;
        return high << 8 | low;
    }

    /**
     * Reads three bytes as an unsigned big-endian 24-bit integer.
     *
     * @param input buffer to read from; its position advances by three
     * @return the value in the range 0..16777215
     * @throws java.nio.BufferUnderflowException if fewer than three bytes remain
     */
    private static int getInt24(ByteBuffer input) {
        final int high = input.get() & 0xFF;
        final int middle = input.get() & 0xFF;
        final int low = input.get() & 0xFF;
        return high << 16 | middle << 8 | low;
    }

    /**
     * Skips a vector whose length is given by a leading one-byte field.
     *
     * @param input buffer positioned at the length byte
     */
    private static void ignoreByteVector8(ByteBuffer input) {
        final int length = SSLExplorer.getInt8(input);
        SSLExplorer.ignoreByteVector(input, length);
    }

    /**
     * Skips a vector whose length is given by a leading two-byte field.
     *
     * @param input buffer positioned at the first length byte
     */
    private static void ignoreByteVector16(ByteBuffer input) {
        final int length = SSLExplorer.getInt16(input);
        SSLExplorer.ignoreByteVector(input, length);
    }

    /**
     * Skips a vector whose length is given by a leading three-byte field.
     *
     * @param input buffer positioned at the first length byte
     */
    private static void ignoreByteVector24(ByteBuffer input) {
        final int length = SSLExplorer.getInt24(input);
        SSLExplorer.ignoreByteVector(input, length);
    }

    /**
     * Advances the buffer position by {@code length} bytes.
     *
     * <p>The length is validated first. {@code Buffer.position(int)} throws
     * {@code IllegalArgumentException} for a target beyond the limit and silently moves
     * backwards for a negative offset; both cases are reported here as a
     * {@link BufferUnderflowException}, the same exception a plain read past the end raises.
     *
     * @param input buffer whose position is advanced
     * @param length number of bytes to skip; zero is a no-op
     * @throws BufferUnderflowException if {@code length} is negative or exceeds the bytes remaining
     */
    private static void ignoreByteVector(ByteBuffer input, int length) {
        if (length == 0) {
            return;
        }
        if (length < 0 || length > input.remaining()) {
            throw new BufferUnderflowException();
        }
        input.position(input.position() + length);
    }

    /**
     * Holder for what was read from a client hello: the record and hello protocol
     * versions as display strings, plus the SNI names, ALPN protocols and cipher suite
     * names. The lists are stored as given and exposed through unmodifiable views.
     */
    static final class SSLConnectionInformationImpl
    implements SSLConnectionInformation {
        /** Display names for major version 3, indexed by minor version. */
        private static final String[] MAJOR_3_NAMES = {"SSLv3", "TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"};

        private final String recordVersion;
        private final String helloVersion;
        private final List<SNIServerName> sniNames;
        private final List<String> alpnProtocols;
        private final List<String> ciphers;

        /**
         * Converts both version pairs to display strings and stores the three lists.
         *
         * @param recordMajorVersion major version from the record header
         * @param recordMinorVersion minor version from the record header
         * @param helloMajorVersion major version from the hello message
         * @param helloMinorVersion minor version from the hello message
         * @param sniNames server names found in the hello
         * @param alpnProtocols application protocols found in the hello
         * @param ciphers names of the cipher suites found in the hello
         */
        SSLConnectionInformationImpl(byte recordMajorVersion, byte recordMinorVersion, byte helloMajorVersion, byte helloMinorVersion, List<SNIServerName> sniNames, List<String> alpnProtocols, List<String> ciphers) {
            this.sniNames = sniNames;
            this.alpnProtocols = alpnProtocols;
            this.ciphers = ciphers;
            this.recordVersion = getVersionString(recordMajorVersion, recordMinorVersion);
            this.helloVersion = getVersionString(helloMajorVersion, helloMinorVersion);
        }

        /**
         * Maps a version pair to its display name: 0.2 is {@code SSLv2Hello}, 3.0 to 3.4
         * are {@code SSLv3} through {@code TLSv1.3}, anything else is {@code Unknown-major.minor}.
         */
        private static String getVersionString(byte helloMajorVersion, byte helloMinorVersion) {
            if (helloMajorVersion == 0 && helloMinorVersion == 2) {
                return "SSLv2Hello";
            }
            if (helloMajorVersion == 3 && helloMinorVersion >= 0 && helloMinorVersion < MAJOR_3_NAMES.length) {
                return MAJOR_3_NAMES[helloMinorVersion];
            }
            return unknownVersion(helloMajorVersion, helloMinorVersion);
        }

        /** Returns the display string computed from the record header version. */
        @Override
        public String getRecordVersion() {
            return recordVersion;
        }

        /** Returns the display string computed from the hello message version. */
        @Override
        public String getHelloVersion() {
            return helloVersion;
        }

        /** Returns an unmodifiable view of the server names. */
        @Override
        public List<SNIServerName> getSNIServerNames() {
            return Collections.unmodifiableList(sniNames);
        }

        /** Returns an unmodifiable view of the ALPN protocol names. */
        @Override
        public List<String> getProtocols() {
            return Collections.unmodifiableList(alpnProtocols);
        }

        /** Returns an unmodifiable view of the cipher suite names. */
        @Override
        public List<String> getCipherSuites() {
            return Collections.unmodifiableList(ciphers);
        }

        /** Formats an unrecognised version pair, printing each byte as an unsigned number. */
        private static String unknownVersion(byte major, byte minor) {
            final int unsignedMajor = major & 0xFF;
            final int unsignedMinor = minor & 0xFF;
            return new StringBuilder("Unknown-")
                    .append(unsignedMajor)
                    .append(".")
                    .append(unsignedMinor)
                    .toString();
        }
    }

    /** Pair of results gathered while walking the extension block of a client hello. */
    static final class ExtensionInfo {
        /** Server names taken from extension type 0. */
        final List<SNIServerName> sni;
        /** Protocol names taken from extension type 16. */
        final List<String> alpn;

        /**
         * Stores both lists as given, without copying.
         *
         * @param serverNames value for {@link #sni}
         * @param protocols value for {@link #alpn}
         */
        ExtensionInfo(List<SNIServerName> serverNames, List<String> protocols) {
            sni = serverNames;
            alpn = protocols;
        }
    }

    /**
     * Concrete {@link SNIServerName} used for name types other than 0; it adds nothing
     * beyond making the superclass constructor reachable.
     */
    static final class UnknownServerName
    extends SNIServerName {
        /**
         * @param type name type passed to the superclass
         * @param encodedName encoded name bytes passed to the superclass
         */
        UnknownServerName(int type, byte[] encodedName) {
            super(type, encodedName);
        }
    }
}

