/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.internal.connection;

import com.mongodb.assertions.Assertions;
import com.mongodb.connection.ProxySettings;
import com.mongodb.internal.connection.DomainNameUtils;
import com.mongodb.internal.connection.InetAddressUtils;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;

/**
 * A {@link Socket} that reaches its endpoint through a SOCKS5 proxy (RFC 1928), optionally using
 * username/password authentication (RFC 1929).
 *
 * <p>If a delegate {@link Socket} is supplied at construction time, every socket operation is forwarded to it;
 * otherwise this instance's own {@link Socket} implementation is used.</p>
 *
 * <p>{@link #connect(SocketAddress, int)} connects to the proxy described by the {@link ProxySettings}, performs
 * method negotiation and (if selected) authentication, and then asks the proxy to CONNECT to the given
 * <em>unresolved</em> endpoint. Host names are sent to the proxy as-is rather than being resolved locally.</p>
 */
public final class SocksSocket extends Socket {
    private static final byte SOCKS_VERSION = 5;
    private static final byte RESERVED = 0;
    private static final byte PORT_LENGTH = 2;
    private static final byte AUTHENTICATION_SUCCEEDED_STATUS = 0;
    public static final String IP_PARSING_ERROR_SUFFIX = " is not an IP string literal";
    private static final byte USER_PASSWORD_SUB_NEGOTIATION_VERSION = 1;
    /** METHOD value (X'FF') a SOCKS5 server returns when none of the offered methods is acceptable. */
    private static final byte NO_ACCEPTABLE_METHODS = (byte) 0xFF;

    /** The unresolved endpoint the proxy is asked to connect to; set by {@link #connect(SocketAddress, int)}. */
    private InetSocketAddress remoteAddress;
    private final ProxySettings proxySettings;
    /** Optional delegate; when non-null, all socket operations are forwarded to it. */
    @Nullable
    private final Socket socket;

    /**
     * Creates a SOCKS socket that uses its own {@link Socket} implementation for the underlying connection.
     *
     * @param proxySettings proxy configuration; its host must be non-null
     */
    public SocksSocket(final ProxySettings proxySettings) {
        this(null, proxySettings);
    }

    /**
     * Creates a SOCKS socket, optionally wrapping an existing, not-yet-connected socket.
     *
     * @param socket        delegate socket to tunnel through, or {@code null} to use this instance itself;
     *                      if non-null it must not be connected yet
     * @param proxySettings proxy configuration; its host must be non-null
     */
    public SocksSocket(@Nullable final Socket socket, final ProxySettings proxySettings) {
        Assertions.assertNotNull(proxySettings.getHost());
        if (socket != null) {
            Assertions.assertFalse(socket.isConnected());
        }
        this.socket = socket;
        this.proxySettings = proxySettings;
    }

    /**
     * Connects to the proxy and establishes a tunnel to {@code endpoint}.
     *
     * <p>The whole sequence (TCP connect to the proxy, negotiation, authentication, CONNECT) shares a single
     * deadline of {@code timeoutMs}; {@code 0} means no deadline. On any {@link IOException} the socket is
     * closed before the exception is rethrown, because a partially negotiated tunnel is unusable.</p>
     *
     * @param endpoint  an <em>unresolved</em> {@link InetSocketAddress} to reach through the proxy
     * @param timeoutMs overall timeout in milliseconds; must be {@code >= 0}, {@code 0} means infinite
     * @throws IOException if connecting, negotiating, authenticating or the CONNECT request fails or times out
     */
    @Override
    public void connect(final SocketAddress endpoint, final int timeoutMs) throws IOException {
        Assertions.isTrueArgument("timeoutMs", timeoutMs >= 0);
        try {
            Timeout timeout = Timeout.expiresIn(timeoutMs, TimeUnit.MILLISECONDS,
                    Timeout.ZeroSemantics.ZERO_DURATION_MEANS_INFINITE);
            InetSocketAddress unresolvedAddress = (InetSocketAddress) endpoint;
            Assertions.assertTrue(unresolvedAddress.isUnresolved());
            this.remoteAddress = unresolvedAddress;

            InetSocketAddress proxyAddress = new InetSocketAddress(
                    Assertions.assertNotNull(proxySettings.getHost()), proxySettings.getPort());
            timeout.checkedRun(TimeUnit.MILLISECONDS,
                    () -> socketConnect(proxyAddress, 0),
                    ms -> socketConnect(proxyAddress, Math.toIntExact(ms)),
                    () -> throwSocketConnectionTimeout());

            SocksAuthenticationMethod authenticationMethod = performNegotiation(timeout);
            authenticate(authenticationMethod, timeout);
            sendConnect(timeout);
        } catch (IOException connectException) {
            // Catch IOException (not just SocketException) so that handshake timeouts, which surface as
            // SocketTimeoutException, also release the half-negotiated socket.
            try {
                close();
            } catch (Exception closeException) {
                connectException.addSuppressed(closeException);
            }
            throw connectException;
        }
    }

    /** Opens the TCP connection to the proxy on the delegate socket if present, otherwise on this socket. */
    private void socketConnect(final InetSocketAddress proxyAddress, final int timeoutMs) throws IOException {
        if (socket != null) {
            socket.connect(proxyAddress, timeoutMs);
        } else {
            super.connect(proxyAddress, timeoutMs);
        }
    }

    /**
     * Sends the SOCKS5 CONNECT request for {@link #remoteAddress} and validates the proxy's reply.
     *
     * <p>Request layout: VER, CMD, RSV, ATYP, DST.ADDR, DST.PORT. Host names are sent as DOMAIN_NAME
     * (length-prefixed ASCII); IP literals are parsed and sent as raw IPv4/IPv6 bytes.</p>
     */
    private void sendConnect(final Timeout timeout) throws IOException {
        String host = remoteAddress.getHostName();
        int port = remoteAddress.getPort();
        byte[] bytesOfHost = host.getBytes(StandardCharsets.US_ASCII);
        int hostLength = bytesOfHost.length;

        AddressType addressType;
        byte[] ipAddress = null;
        if (DomainNameUtils.isDomainName(host)) {
            addressType = AddressType.DOMAIN_NAME;
        } else {
            ipAddress = createByteArrayFromIpAddress(host);
            addressType = determineAddressType(ipAddress);
        }

        byte[] bufferSent = createBuffer(addressType, hostLength);
        bufferSent[0] = SOCKS_VERSION;
        bufferSent[1] = SocksCommand.CONNECT.getCommandNumber();
        bufferSent[2] = RESERVED;
        // ATYP must describe the DST.ADDR encoding that follows; deriving it from addressType (instead of
        // hard-coding it per branch) guarantees IPv6 requests are not mislabelled as domain names.
        bufferSent[3] = addressType.getAddressTypeNumber();
        if (addressType == AddressType.DOMAIN_NAME) {
            // One length octet followed by the host name itself.
            bufferSent[4] = (byte) hostLength;
            System.arraycopy(bytesOfHost, 0, bufferSent, 5, hostLength);
            addPort(bufferSent, 5 + hostLength, port);
        } else {
            byte[] address = Assertions.assertNotNull(ipAddress);
            System.arraycopy(address, 0, bufferSent, 4, address.length);
            addPort(bufferSent, 4 + address.length, port);
        }

        OutputStream outputStream = getOutputStream();
        outputStream.write(bufferSent);
        outputStream.flush();
        checkServerReply(timeout);
    }

    /** Writes {@code port} as two big-endian octets at {@code index} and {@code index + 1}. */
    private static void addPort(final byte[] bufferSent, final int index, final int port) {
        bufferSent[index] = (byte) (port >> 8);
        bufferSent[index + 1] = (byte) port;
    }

    /**
     * Parses an IP string literal into its raw address bytes.
     *
     * @throws SocketException if {@code host} is not an IP string literal
     */
    private static byte[] createByteArrayFromIpAddress(final String host) throws SocketException {
        byte[] bytes = InetAddressUtils.ipStringToBytes(host);
        if (bytes == null) {
            throw new SocketException(host + IP_PARSING_ERROR_SUFFIX);
        }
        return bytes;
    }

    /** Maps a raw address (4 or 16 bytes) to its SOCKS address type; any other length is an assertion failure. */
    private static AddressType determineAddressType(final byte[] ipAddress) {
        if (ipAddress.length == AddressType.IP_V4.getLength()) {
            return AddressType.IP_V4;
        }
        if (ipAddress.length == AddressType.IP_V6.getLength()) {
            return AddressType.IP_V6;
        }
        throw Assertions.fail();
    }

    /**
     * Allocates a CONNECT request buffer: 4 header octets + address + 2 port octets
     * (plus one length octet for DOMAIN_NAME).
     */
    private static byte[] createBuffer(final AddressType addressType, final int hostLength) {
        switch (addressType) {
            case DOMAIN_NAME:
                return new byte[7 + hostLength];
            case IP_V4:
                return new byte[6 + AddressType.IP_V4.getLength()];
            case IP_V6:
                return new byte[6 + AddressType.IP_V6.getLength()];
            default:
                throw Assertions.fail();
        }
    }

    /**
     * Reads the CONNECT reply, throwing if the proxy did not report success, and consumes the trailing
     * BND.ADDR/BND.PORT fields so the stream is positioned at the start of tunnelled data.
     */
    private void checkServerReply(final Timeout timeout) throws IOException {
        byte[] data = readSocksReply(4, timeout);
        ServerReply reply = ServerReply.of(data[1]);
        if (reply != ServerReply.REPLY_SUCCEEDED) {
            throw new ConnectException(reply.getMessage());
        }
        AddressType boundAddressType = AddressType.of(data[3]);
        switch (boundAddressType) {
            case DOMAIN_NAME: {
                // The length octet is unsigned (0..255); a signed byte would go negative for lengths >= 128.
                int hostNameLength = Byte.toUnsignedInt(readSocksReply(1, timeout)[0]);
                readSocksReply(hostNameLength + PORT_LENGTH, timeout);
                break;
            }
            case IP_V4:
            case IP_V6:
                readSocksReply(boundAddressType.getLength() + PORT_LENGTH, timeout);
                break;
            default:
                throw Assertions.fail();
        }
    }

    /**
     * Performs RFC 1929 username/password sub-negotiation when that method was selected; otherwise does nothing.
     *
     * @throws ConnectException if the proxy reports a non-zero authentication status
     */
    private void authenticate(final SocksAuthenticationMethod authenticationMethod, final Timeout timeout)
            throws IOException {
        if (authenticationMethod != SocksAuthenticationMethod.USERNAME_PASSWORD) {
            return;
        }
        byte[] bytesOfUsername = Assertions.assertNotNull(proxySettings.getUsername()).getBytes(StandardCharsets.UTF_8);
        byte[] bytesOfPassword = Assertions.assertNotNull(proxySettings.getPassword()).getBytes(StandardCharsets.UTF_8);
        int usernameLength = bytesOfUsername.length;
        int passwordLength = bytesOfPassword.length;

        // Layout: VER, ULEN, UNAME, PLEN, PASSWD.
        byte[] command = new byte[3 + usernameLength + passwordLength];
        command[0] = USER_PASSWORD_SUB_NEGOTIATION_VERSION;
        command[1] = (byte) usernameLength;
        System.arraycopy(bytesOfUsername, 0, command, 2, usernameLength);
        command[2 + usernameLength] = (byte) passwordLength;
        System.arraycopy(bytesOfPassword, 0, command, 3 + usernameLength, passwordLength);

        OutputStream outputStream = getOutputStream();
        outputStream.write(command);
        outputStream.flush();

        byte[] authResult = readSocksReply(2, timeout);
        byte authStatus = authResult[1];
        if (authStatus != AUTHENTICATION_SUCCEEDED_STATUS) {
            throw new ConnectException("Authentication failed. Proxy server returned status: "
                    + Byte.toUnsignedInt(authStatus));
        }
    }

    /**
     * Sends the SOCKS5 greeting listing the supported authentication methods and returns the one the proxy
     * selected.
     *
     * @throws ConnectException if the proxy is not SOCKS5, rejects every offered method, or selects a method
     *                          that was not offered
     */
    private SocksAuthenticationMethod performNegotiation(final Timeout timeout) throws IOException {
        SocksAuthenticationMethod[] authenticationMethods = getSocksAuthenticationMethods();
        int methodsCount = authenticationMethods.length;

        // Layout: VER, NMETHODS, METHODS...
        byte[] bufferSent = new byte[2 + methodsCount];
        bufferSent[0] = SOCKS_VERSION;
        bufferSent[1] = (byte) methodsCount;
        for (int i = 0; i < methodsCount; i++) {
            bufferSent[2 + i] = authenticationMethods[i].getMethodNumber();
        }

        OutputStream outputStream = getOutputStream();
        outputStream.write(bufferSent);
        outputStream.flush();

        byte[] handshakeReply = readSocksReply(2, timeout);
        if (handshakeReply[0] != SOCKS_VERSION) {
            throw new ConnectException("Remote server doesn't support socks version 5 Received version: "
                    + Byte.toUnsignedInt(handshakeReply[0]));
        }
        byte authMethodNumber = handshakeReply[1];
        if (authMethodNumber == NO_ACCEPTABLE_METHODS) {
            throw new ConnectException("None of the authentication methods listed are acceptable. Attempted methods: "
                    + Arrays.toString((Object[]) authenticationMethods));
        }
        // Only accept a method we actually offered; anything else is a protocol violation by the proxy.
        for (SocksAuthenticationMethod offered : authenticationMethods) {
            if (offered.getMethodNumber() == authMethodNumber) {
                return offered;
            }
        }
        throw new ConnectException("Proxy returned unsupported authentication method: "
                + Byte.toUnsignedInt(authMethodNumber));
    }

    /** Offers NO_AUTH always, and additionally USERNAME_PASSWORD when a username is configured. */
    private SocksAuthenticationMethod[] getSocksAuthenticationMethods() {
        if (proxySettings.getUsername() != null) {
            return new SocksAuthenticationMethod[]{SocksAuthenticationMethod.NO_AUTH,
                    SocksAuthenticationMethod.USERNAME_PASSWORD};
        }
        return new SocksAuthenticationMethod[]{SocksAuthenticationMethod.NO_AUTH};
    }

    /**
     * Reads exactly {@code length} bytes from the proxy, honouring the remaining time of {@code timeout}.
     *
     * <p>SO_TIMEOUT is re-armed before each read from the time remaining on {@code timeout}, and the original
     * SO_TIMEOUT is restored afterwards.</p>
     *
     * @throws ConnectException       if the stream ends before {@code length} bytes arrive
     * @throws SocketTimeoutException if the deadline has already passed
     */
    private byte[] readSocksReply(final int length, final Timeout timeout) throws IOException {
        InputStream inputStream = getInputStream();
        byte[] data = new byte[length];
        int originalTimeout = getSoTimeout();
        try {
            int received = 0;
            while (received < length) {
                timeout.checkedRun(TimeUnit.MILLISECONDS,
                        () -> setSoTimeout(0),
                        remainingMs -> setSoTimeout(Math.toIntExact(remainingMs)),
                        () -> throwSocketConnectionTimeout());
                int count = inputStream.read(data, received, length - received);
                if (count < 0) {
                    throw new ConnectException("Malformed reply from SOCKS proxy server");
                }
                received += count;
            }
        } finally {
            setSoTimeout(originalTimeout);
        }
        return data;
    }

    private static void throwSocketConnectionTimeout() throws SocketTimeoutException {
        throw new SocketTimeoutException("Socket connection timed out");
    }

    /** Closes this socket and, if present, the delegate socket (the delegate is closed even if this one fails). */
    @Override
    public void close() throws IOException {
        try (Socket autoClosed = socket) {
            super.close();
        }
    }

    // ---- Delegation: each method below forwards to the delegate socket when one was supplied. ----

    @Override
    public void setSoTimeout(final int timeout) throws SocketException {
        if (socket != null) {
            socket.setSoTimeout(timeout);
        } else {
            super.setSoTimeout(timeout);
        }
    }

    @Override
    public int getSoTimeout() throws SocketException {
        return socket != null ? socket.getSoTimeout() : super.getSoTimeout();
    }

    @Override
    public void bind(final SocketAddress bindpoint) throws IOException {
        if (socket != null) {
            socket.bind(bindpoint);
        } else {
            super.bind(bindpoint);
        }
    }

    @Override
    public InetAddress getInetAddress() {
        return socket != null ? socket.getInetAddress() : super.getInetAddress();
    }

    @Override
    public InetAddress getLocalAddress() {
        return socket != null ? socket.getLocalAddress() : super.getLocalAddress();
    }

    @Override
    public int getPort() {
        return socket != null ? socket.getPort() : super.getPort();
    }

    @Override
    public int getLocalPort() {
        return socket != null ? socket.getLocalPort() : super.getLocalPort();
    }

    @Override
    public SocketAddress getRemoteSocketAddress() {
        return socket != null ? socket.getRemoteSocketAddress() : super.getRemoteSocketAddress();
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return socket != null ? socket.getLocalSocketAddress() : super.getLocalSocketAddress();
    }

    @Override
    public SocketChannel getChannel() {
        return socket != null ? socket.getChannel() : super.getChannel();
    }

    @Override
    public void setTcpNoDelay(final boolean on) throws SocketException {
        if (socket != null) {
            socket.setTcpNoDelay(on);
        } else {
            super.setTcpNoDelay(on);
        }
    }

    @Override
    public boolean getTcpNoDelay() throws SocketException {
        return socket != null ? socket.getTcpNoDelay() : super.getTcpNoDelay();
    }

    @Override
    public void setSoLinger(final boolean on, final int linger) throws SocketException {
        if (socket != null) {
            socket.setSoLinger(on, linger);
        } else {
            super.setSoLinger(on, linger);
        }
    }

    @Override
    public int getSoLinger() throws SocketException {
        return socket != null ? socket.getSoLinger() : super.getSoLinger();
    }

    @Override
    public void sendUrgentData(final int data) throws IOException {
        if (socket != null) {
            socket.sendUrgentData(data);
        } else {
            super.sendUrgentData(data);
        }
    }

    @Override
    public void setOOBInline(final boolean on) throws SocketException {
        if (socket != null) {
            socket.setOOBInline(on);
        } else {
            super.setOOBInline(on);
        }
    }

    @Override
    public boolean getOOBInline() throws SocketException {
        return socket != null ? socket.getOOBInline() : super.getOOBInline();
    }

    @Override
    public void setSendBufferSize(final int size) throws SocketException {
        if (socket != null) {
            socket.setSendBufferSize(size);
        } else {
            super.setSendBufferSize(size);
        }
    }

    @Override
    public int getSendBufferSize() throws SocketException {
        return socket != null ? socket.getSendBufferSize() : super.getSendBufferSize();
    }

    @Override
    public void setReceiveBufferSize(final int size) throws SocketException {
        if (socket != null) {
            socket.setReceiveBufferSize(size);
        } else {
            super.setReceiveBufferSize(size);
        }
    }

    @Override
    public int getReceiveBufferSize() throws SocketException {
        return socket != null ? socket.getReceiveBufferSize() : super.getReceiveBufferSize();
    }

    @Override
    public void setKeepAlive(final boolean on) throws SocketException {
        if (socket != null) {
            socket.setKeepAlive(on);
        } else {
            super.setKeepAlive(on);
        }
    }

    @Override
    public boolean getKeepAlive() throws SocketException {
        return socket != null ? socket.getKeepAlive() : super.getKeepAlive();
    }

    @Override
    public void setTrafficClass(final int tc) throws SocketException {
        if (socket != null) {
            socket.setTrafficClass(tc);
        } else {
            super.setTrafficClass(tc);
        }
    }

    @Override
    public int getTrafficClass() throws SocketException {
        return socket != null ? socket.getTrafficClass() : super.getTrafficClass();
    }

    @Override
    public void setReuseAddress(final boolean on) throws SocketException {
        if (socket != null) {
            socket.setReuseAddress(on);
        } else {
            super.setReuseAddress(on);
        }
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return socket != null ? socket.getReuseAddress() : super.getReuseAddress();
    }

    @Override
    public void shutdownInput() throws IOException {
        if (socket != null) {
            socket.shutdownInput();
        } else {
            super.shutdownInput();
        }
    }

    @Override
    public void shutdownOutput() throws IOException {
        if (socket != null) {
            socket.shutdownOutput();
        } else {
            super.shutdownOutput();
        }
    }

    @Override
    public boolean isConnected() {
        return socket != null ? socket.isConnected() : super.isConnected();
    }

    @Override
    public boolean isBound() {
        return socket != null ? socket.isBound() : super.isBound();
    }

    @Override
    public boolean isClosed() {
        return socket != null ? socket.isClosed() : super.isClosed();
    }

    @Override
    public boolean isInputShutdown() {
        return socket != null ? socket.isInputShutdown() : super.isInputShutdown();
    }

    @Override
    public boolean isOutputShutdown() {
        return socket != null ? socket.isOutputShutdown() : super.isOutputShutdown();
    }

    @Override
    public void setPerformancePreferences(final int connectionTime, final int latency, final int bandwidth) {
        if (socket != null) {
            socket.setPerformancePreferences(connectionTime, latency, bandwidth);
        } else {
            super.setPerformancePreferences(connectionTime, latency, bandwidth);
        }
    }

    @Override
    public InputStream getInputStream() throws IOException {
        return socket != null ? socket.getInputStream() : super.getInputStream();
    }

    @Override
    public OutputStream getOutputStream() throws IOException {
        return socket != null ? socket.getOutputStream() : super.getOutputStream();
    }

    /** SOCKS5 authentication methods this client can offer (RFC 1928 section 3). */
    private enum SocksAuthenticationMethod {
        NO_AUTH(0),
        USERNAME_PASSWORD(2);

        private final byte methodNumber;

        SocksAuthenticationMethod(final int methodNumber) {
            this.methodNumber = (byte) methodNumber;
        }

        public byte getMethodNumber() {
            return methodNumber;
        }
    }

    /** SOCKS5 ATYP values with the raw address length for fixed-size types. */
    enum AddressType {
        IP_V4(1, 4),
        IP_V6(4, 16),
        /** Variable length (prefixed by a length octet), so {@link #getLength()} is not applicable. */
        DOMAIN_NAME(3, -1) {
            @Override
            public byte getLength() {
                throw Assertions.fail();
            }
        };

        private final byte length;
        private final byte addressTypeNumber;

        AddressType(final int addressTypeNumber, final int length) {
            this.addressTypeNumber = (byte) addressTypeNumber;
            this.length = (byte) length;
        }

        /**
         * Decodes an ATYP octet.
         *
         * @throws ConnectException if the octet is not a known address type
         */
        static AddressType of(final byte signedAddressType) throws ConnectException {
            int addressTypeNumber = Byte.toUnsignedInt(signedAddressType);
            for (AddressType addressType : AddressType.values()) {
                if (addressTypeNumber == addressType.getAddressTypeNumber()) {
                    return addressType;
                }
            }
            throw new ConnectException("Reply from SOCKS proxy server contains wrong address type Address type: "
                    + addressTypeNumber);
        }

        byte getLength() {
            return length;
        }

        byte getAddressTypeNumber() {
            return addressTypeNumber;
        }
    }

    /** SOCKS5 request commands used by this client. */
    enum SocksCommand {
        CONNECT(1);

        private final byte value;

        SocksCommand(final int value) {
            this.value = (byte) value;
        }

        public byte getCommandNumber() {
            return value;
        }
    }

    /** SOCKS5 REP codes from the proxy's CONNECT reply, with human-readable messages. */
    enum ServerReply {
        REPLY_SUCCEEDED(0, "Succeeded"),
        GENERAL_FAILURE(1, "General SOCKS5 server failure"),
        NOT_ALLOWED(2, "Connection is not allowed by ruleset"),
        NET_UNREACHABLE(3, "Network is unreachable"),
        HOST_UNREACHABLE(4, "Host is unreachable"),
        CONN_REFUSED(5, "Connection has been refused"),
        TTL_EXPIRED(6, "TTL is expired"),
        CMD_NOT_SUPPORTED(7, "Command is not supported"),
        ADDR_TYPE_NOT_SUP(8, "Address type is not supported");

        private final int replyNumber;
        private final String message;

        ServerReply(final int replyNumber, final String message) {
            this.replyNumber = replyNumber;
            this.message = message;
        }

        /**
         * Decodes a REP octet.
         *
         * @throws ConnectException if the octet is not a known reply code
         */
        static ServerReply of(final byte byteStatus) throws ConnectException {
            int status = Byte.toUnsignedInt(byteStatus);
            for (ServerReply serverReply : ServerReply.values()) {
                if (status == serverReply.replyNumber) {
                    return serverReply;
                }
            }
            throw new ConnectException("Unknown reply field. Reply field: " + status);
        }

        public String getMessage() {
            return message;
        }
    }
}

