/*
 *  Copyright 2022 Anyware Services
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.ametys.plugins.extrausermgt.authentication.kerberos;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionException;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.auth.login.Configuration;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;

import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;

import org.ametys.runtime.model.checker.ItemChecker;
import org.ametys.runtime.model.checker.ItemCheckerTestFailureException;
import org.ametys.runtime.plugin.component.AbstractLogEnabled;

import com.google.common.net.InetAddresses;

/**
 * This checks that the parameters are the one of a Kerberos server
 */
public class KerberosChecker extends AbstractLogEnabled implements ItemChecker
{
    public void check(List<String> values) throws ItemCheckerTestFailureException
    {
        String realm = values.get(0);
        String svcLogin = values.get(1);
        String svcPassword = values.get(2);
        String kdc = values.get(3);
        String testDomain = values.get(4);
        String testLogin = values.get(5);
        String testPassword = values.get(6);
        
        try
        {
            System.setProperty("java.security.krb5.kdc", kdc);
            
            Configuration loginConfig = null;
            if (System.getProperty("java.security.auth.login.config") == null)
            {
                loginConfig = new Configuration() 
                {
                    @Override
                    public AppConfigurationEntry[] getAppConfigurationEntry(String name) 
                    {
                        return new AppConfigurationEntry[] {new AppConfigurationEntry("com.sun.security.auth.module.Krb5LoginModule", AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, Map.of())};
                    }
                };
            }
            
            LoginContext loginContext = new LoginContext("kerberos-client", null, new CallbackHandler()
            {
                public void handle(final Callback[] callbacks)
                {
                    for (Callback callback : callbacks)
                    {
                        if (callback instanceof NameCallback)
                        {
                            ((NameCallback) callback).setName(testLogin + "@" + realm.toUpperCase());
                        }
                        else if (callback instanceof PasswordCallback)
                        {
                            ((PasswordCallback) callback).setPassword(testPassword.toCharArray());
                        }
                        else
                        {
                            throw new RuntimeException("Invalid callback received during KerberosCredentialProvider initialization");
                        }
                    }
                }
            }, loginConfig);
          
            getLogger().debug("***** Authenticating " + testLogin);
            
            loginContext.login();
            Subject subject = loginContext.getSubject();
            
            getLogger().debug("***** TGT obtained");
            getLogger().debug(subject.toString());

            GSSManager manager = GSSManager.getInstance();
            
            Callable<GSSCredential> action = new Callable<>() 
            {
                public GSSCredential call() throws GSSException 
                {
                    return manager.createCredential(null, GSSCredential.INDEFINITE_LIFETIME, new Oid("1.3.6.1.5.5.2"), GSSCredential.INITIATE_ONLY);
                } 
            };
            
            GSSCredential gssCredential = Subject.callAs(subject, action);
            
            String receivedToken = null;
            try
            {
                receivedToken = _getToken(manager, testDomain, realm, gssCredential);
            }
            catch (GSSException e)
            {
                if (e.getMajor() == 13) // no valid credentials, possibly due to wrong SPN
                {
                    String resolvedDomain = null;
                    
                    try
                    {
                        resolvedDomain = InetAddress.getByName(testDomain).getCanonicalHostName();
                    }
                    catch (UnknownHostException ex)
                    {
                        getLogger().debug("***** Cannot get ticket for host {} and also fail to resolve", testDomain, ex);
                        throw e; // rethrow the initial exception
                    }
                    
                    if (InetAddresses.isInetAddress(resolvedDomain) || resolvedDomain.equals(testDomain))
                    {
                        // reverse DNS not set or resolved to the same host => nothing to do
                        throw e;
                    }
                    
                    getLogger().debug("***** Cannot get ticket for host {}, try with {}", resolvedDomain);
                    receivedToken = _getToken(manager, resolvedDomain, realm, gssCredential);
                }
                else
                {
                    throw e;
                }
            }

            getLogger().debug("***** Decoding token");
            
            LoginContext srvLoginContext = KerberosCredentialProvider.createLoginContext(realm, svcLogin, svcPassword);
            
            action = new Callable<>() 
            {
                public GSSCredential call() throws GSSException 
                {
                    return manager.createCredential(null, GSSCredential.INDEFINITE_LIFETIME, new Oid("1.3.6.1.5.5.2"), GSSCredential.ACCEPT_ONLY);
                } 
            };
            
            gssCredential = Subject.callAs(srvLoginContext.getSubject(), action);
            GSSContext gssContext = GSSManager.getInstance().createContext(gssCredential);

            byte[] token = java.util.Base64.getDecoder().decode(receivedToken);
          
            gssContext.acceptSecContext(token, 0, token.length);

            GSSName gssSrcName = gssContext.getSrcName();
            getLogger().debug("***** User authenticated: " + gssSrcName);
        }
        catch (LoginException | GSSException e)
        {
            throw new ItemCheckerTestFailureException("Unable to connect to the KDC (" + e.getMessage() + ")", e);
        }
    }
    
    private String _getToken(GSSManager manager, String host, String realm, GSSCredential gssCredential) throws GSSException
    {
        getLogger().debug("***** Getting ticket for {}", host);
        
        GSSName peer = manager.createName("HTTP/" + host + "@" + realm.toUpperCase(), GSSName.NT_USER_NAME);
        GSSContext gssContext = GSSManager.getInstance().createContext(peer, new Oid("1.3.6.1.5.5.2"), gssCredential, GSSContext.INDEFINITE_LIFETIME);
        
        byte[] kdcTokenAnswer = gssContext.initSecContext(new byte[0], 0, 0);
        String receivedToken = kdcTokenAnswer != null ? java.util.Base64.getEncoder().encodeToString(kdcTokenAnswer) : null;

        getLogger().debug("***** Token generated\n{}", receivedToken);
        return receivedToken;
    }
}
