/*
 *  Copyright 2024 Anyware Services
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.ametys.plugins.extrausermgt.authentication.msal;

import java.io.IOException;
import java.net.URI;
import java.util.Date;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import org.apache.avalon.framework.context.Context;
import org.apache.avalon.framework.context.ContextException;
import org.apache.avalon.framework.context.Contextualizable;
import org.apache.cocoon.ProcessingException;
import org.apache.cocoon.components.ContextHelper;
import org.apache.cocoon.environment.ObjectModelHelper;
import org.apache.cocoon.environment.Redirector;
import org.apache.cocoon.environment.Request;
import org.apache.cocoon.environment.Session;

import org.ametys.core.authentication.AbstractCredentialProvider;
import org.ametys.core.authentication.BlockingCredentialProvider;
import org.ametys.core.authentication.NonBlockingCredentialProvider;
import org.ametys.core.user.UserIdentity;
import org.ametys.plugins.extrausermgt.authentication.oidc.AbstractOIDCCredentialProvider;
import org.ametys.plugins.extrausermgt.authentication.oidc.OIDCBasedCredentialProvider;
import org.ametys.runtime.authentication.AccessDeniedException;
import org.ametys.workspaces.extrausermgt.authentication.oidc.OIDCCallbackAction;

import com.microsoft.aad.msal4j.AuthorizationCodeParameters;
import com.microsoft.aad.msal4j.AuthorizationRequestUrlParameters;
import com.microsoft.aad.msal4j.AuthorizationRequestUrlParameters.Builder;
import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ConfidentialClientApplication;
import com.microsoft.aad.msal4j.IAccount;
import com.microsoft.aad.msal4j.IAuthenticationResult;
import com.microsoft.aad.msal4j.IClientSecret;
import com.microsoft.aad.msal4j.Prompt;
import com.microsoft.aad.msal4j.ResponseMode;
import com.microsoft.aad.msal4j.SilentParameters;
import com.nimbusds.jwt.SignedJWT;

/**
 * Sign in through Entra ID, using the OpenId Connect protocol.
 */
public abstract class AbstractMSALCredentialProvider extends AbstractCredentialProvider implements OIDCBasedCredentialProvider, BlockingCredentialProvider, NonBlockingCredentialProvider, Contextualizable
{
    /** Session attribute to store the access token */
    public static final String ACCESS_TOKEN_SESSION_ATTRIBUTE = "msal_token";
    /** Session attribute storing the expiration date of the current access token */
    private static final String __ATTRIBUTE_EXPIRATIONDATE = "msal_expirationDate";
    /** Session attribute storing the MSAL account of the signed-in user */
    private static final String __ATTRIBUTE_ACCOUNT = "msal_account";
    /** Session attribute storing the serialized MSAL token cache */
    private static final String __ATTRIBUTE_TOKENCACHE = "msal_tokenCache";
    /** Session attribute storing the authorization code between the callback and the redirection to the initial URI */
    private static final String __ATTRIBUTE_CODE = "msal_code";
    /** Session attribute set to "true" once a silent sign-in has been attempted */
    private static final String __ATTRIBUTE_SILENT = "msal_silent";
    /** Session attribute storing the state sent with the authorization request */
    private static final String __ATTRIBUTE_STATE = "msal_state";
    /** Session attribute storing the nonce sent with the authorization request */
    private static final String __ATTRIBUTE_NONCE = "msal_nonce";
    
    /** the OIDC app id */
    protected String _clientID;
    /** the client secret */
    protected String _clientSecret;
    /** whether the user should be explicitly forced to enter its username */
    protected boolean _prompt;
    /** whether we should try to silently log the user in */
    protected boolean _silent;

    private Context _context;
    
    @Override
    public void contextualize(Context context) throws ContextException
    {
        _context = context;
    }
    
    /**
     * Set the mandatory properties. Should be called by implementors as early as possible.
     * @param cliendId the OIDC app id
     * @param clientSecret the client secret
     * @param prompt whether the user should be explicitly forced to enter its username
     * @param silent whether we should try to silently log the user in
     */
    protected void init(String cliendId, String clientSecret, boolean prompt, boolean silent)
    {
        _clientID = cliendId;
        _clientSecret = clientSecret;
        _prompt = prompt;
        _silent = silent;
    }
    
    /**
     * Builds a new MSAL confidential client from the configured client id, client secret and authority.
     * @return the MSAL client
     * @throws Exception if the client cannot be built
     */
    private ConfidentialClientApplication _getClient() throws Exception
    {
        IClientSecret secret = ClientCredentialFactory.createFromSecret(_clientSecret);
        return ConfidentialClientApplication.builder(_clientID, secret)
                                            .authority(getAuthority())
                                            .build();
    }
    
    /**
     * Returns the URL to send authorization and token requests to.
     * @return the OIDC authority URL
     */
    protected abstract String getAuthority();
    
    /**
     * Returns the OIDC app id given to {@link #init(String, String, boolean, boolean)}.
     * @return the OIDC app id
     */
    public String getClientId()
    {
        return _clientID;
    }
    
    /**
     * Retrieves the current request from the component context.
     * @return the current request
     */
    private Request _getRequest()
    {
        return ObjectModelHelper.getRequest(ContextHelper.getObjectModel(_context));
    }
    
    @Override
    public boolean blockingIsStillConnected(UserIdentity userIdentity, Redirector redirector) throws Exception
    {
        Session session = _getRequest().getSession(true);
        
        // The user is always considered connected: only the access token is renewed when it has expired
        refreshTokenIfNeeded(session);
        
        return true;
    }
    
    @Override
    public boolean nonBlockingIsStillConnected(UserIdentity userIdentity, Redirector redirector) throws Exception
    {
        return blockingIsStillConnected(userIdentity, redirector);
    }
    
    @Override
    public boolean blockingGrantAnonymousRequest()
    {
        return false;
    }
    
    @Override
    public boolean nonBlockingGrantAnonymousRequest()
    {
        return false;
    }
    
    /**
     * Computes the absolute callback URI for the given request:
     * scheme, server name, port (omitted when it is the default one for the scheme),
     * context path and the OIDC callback path.
     * @param request the current request
     * @return the absolute callback URI
     */
    private String _getRequestURI(Request request)
    {
        boolean secure = request.isSecure();
        int port = request.getServerPort();
        int defaultPort = secure ? 443 : 80;
        
        StringBuilder uriBuilder = new StringBuilder();
        uriBuilder.append(secure ? "https://" : "http://").append(request.getServerName());
        if (port != defaultPort)
        {
            uriBuilder.append(":").append(port);
        }
        
        uriBuilder.append(request.getContextPath());
        uriBuilder.append(OIDCCallbackAction.CALLBACK_URL);
        return uriBuilder.toString();
    }

    /**
     * Runs one step of the sign-in process, depending on what the session and the request hold:
     * <ul>
     * <li>an authorization code is stored in session: it is exchanged for tokens and the user identity is returned;</li>
     * <li>the request has no "code" parameter: the client is redirected to the authorization URL (unless a silent attempt was already made);</li>
     * <li>the request has a "code" parameter: the state is checked, the code is stored in session and the client is redirected to the initially requested URI.</li>
     * </ul>
     * @param silent true to request a sign-in without any user interaction
     * @param redirector the redirector
     * @return the user identity, or null if the client was redirected or if the silent sign-in was already attempted
     * @throws Exception if an error occurs, or if the state or nonce checks fail
     */
    private UserIdentity _login(boolean silent, Redirector redirector) throws Exception
    {
        Request request = _getRequest();
        Session session = request.getSession(true);
        
        ConfidentialClientApplication client = _getClient();

        String requestURI = _getRequestURI(request);
        getLogger().debug("MSAL CredentialProvider callback URI: {}", requestURI);

        String storedCode = (String) session.getAttribute(__ATTRIBUTE_CODE);
        if (storedCode != null)
        {
            // An authorization code can be redeemed only once: forget it before using it,
            // so that a failed or already completed exchange is not replayed on every later sign-in attempt
            session.setAttribute(__ATTRIBUTE_CODE, null);
            return _getUserIdentityFromCode(storedCode, session, client, requestURI);
        }
        
        boolean wasSilent = silent && "true".equals(session.getAttribute(__ATTRIBUTE_SILENT));
        
        String code = request.getParameter("code");
        if (code == null)
        {
            // sign-in request: redirect the client through the actual authentication process
            
            if (wasSilent)
            {
                // already passed through this, there should have been some error somewhere
                return null;
            }
            
            if (silent)
            {
                session.setAttribute(__ATTRIBUTE_SILENT, "true");
            }
            
            String state = UUID.randomUUID().toString();
            session.setAttribute(__ATTRIBUTE_STATE, state);
            
            // remember the URI currently requested, to come back to it once the authorization code is received
            String actualRedirectUri = request.getRequestURI();
            if (request.getQueryString() != null)
            {
                actualRedirectUri += "?" + request.getQueryString();
            }
            session.setAttribute(AbstractOIDCCredentialProvider.REDIRECT_URI_SESSION_ATTRIBUTE, actualRedirectUri);
            
            String nonce = UUID.randomUUID().toString();
            session.setAttribute(__ATTRIBUTE_NONCE, nonce);
            
            Builder builder = AuthorizationRequestUrlParameters.builder(requestURI, getScopes())
                                                               .responseMode(ResponseMode.QUERY)
                                                               .state(state)
                                                               .nonce(nonce);
            
            if (silent)
            {
                builder.prompt(Prompt.NONE);
            }
            else if (_prompt)
            {
                builder.prompt(Prompt.SELECT_ACCOUNT);
            }
            
            AuthorizationRequestUrlParameters parameters = builder.build();

            String authorizationRequestUrl = client.getAuthorizationRequestUrl(parameters).toString();
            redirector.redirect(false, authorizationRequestUrl);
            return null;
        }
        
        // we got an authorization code,
        
        // but first, check the state to prevent CSRF attacks
        // (a callback received while no state is stored in session is rejected as well)
        String storedState = (String) session.getAttribute(__ATTRIBUTE_STATE);
        String state = request.getParameter("state");
        
        if (storedState == null || !storedState.equals(state))
        {
            throw new AccessDeniedException("MSAL state mismatch");
        }
        
        session.setAttribute(__ATTRIBUTE_STATE, null);
        
        // then store the authorization code
        session.setAttribute(__ATTRIBUTE_CODE, code);
        
        // and finally redirect to initial URI
        String redirectUri = (String) session.getAttribute(AbstractOIDCCredentialProvider.REDIRECT_URI_SESSION_ATTRIBUTE);
        redirector.redirect(true, redirectUri);
        return null;
    }
    
    /**
     * Returns all needed OIDC scopes. Defaults to ["openid"]
     * @return all needed OIDC scopes
     */
    protected Set<String> getScopes()
    {
        return Set.of("openid");
    }
    
    /**
     * Exchanges the given authorization code for tokens, checks the nonce of the id token,
     * stores the token data in session and computes the user identity.
     * @param code the authorization code
     * @param session the current session
     * @param client the MSAL client
     * @param requestURI the callback URI used in the authorization request
     * @return the identity of the signed-in user (with no population set)
     * @throws Exception if the tokens cannot be acquired or if the nonce does not match
     */
    private UserIdentity _getUserIdentityFromCode(String code, Session session, ConfidentialClientApplication client, String requestURI) throws Exception
    {
        AuthorizationCodeParameters authParams = AuthorizationCodeParameters.builder(code, new URI(requestURI))
                                                                            .scopes(getScopes())
                                                                            .build();

        IAuthenticationResult result = client.acquireToken(authParams).get();
        
        // check nonce
        Map<String, Object> tokenClaims = SignedJWT.parse(result.idToken()).getJWTClaimsSet().getClaims();
        
        String storedNonce = (String) session.getAttribute(__ATTRIBUTE_NONCE);
        String nonce = (String) tokenClaims.get("nonce");
        
        // a token received while no nonce is stored in session is rejected as well
        if (storedNonce == null || !storedNonce.equals(nonce))
        {
            throw new AccessDeniedException("MSAL nonce mismatch");
        }
        
        session.setAttribute(__ATTRIBUTE_NONCE, null);
        
        // keep what refreshTokenIfNeeded needs to renew the access token later on
        session.setAttribute(__ATTRIBUTE_EXPIRATIONDATE, result.expiresOnDate());
        session.setAttribute(__ATTRIBUTE_TOKENCACHE, client.tokenCache().serialize());
        session.setAttribute(__ATTRIBUTE_ACCOUNT, result.account());
        
        session.setAttribute(ACCESS_TOKEN_SESSION_ATTRIBUTE, result.accessToken());
        
        // then the user is finally logged in
        String login = getLogin(result);
        
        return new UserIdentity(login, null);
    }
    
    /**
     * Retrieves the login from the given authentication result
     * @param result the authentication result
     * @return the login
     */
    protected String getLogin(IAuthenticationResult result)
    {
        return result.account().username();
    }
    
    @Override
    public UserIdentity blockingGetUserIdentity(Redirector redirector) throws Exception
    {
        return _login(false, redirector);
    }
    
    @Override
    public UserIdentity nonBlockingGetUserIdentity(Redirector redirector) throws Exception
    {
        if (!_silent)
        {
            // silent sign-in is disabled: nothing to attempt in non-blocking mode
            return null;
        }
        
        return _login(true, redirector);
    }
    
    @Override
    public void blockingUserNotAllowed(Redirector redirector)
    {
        // Nothing to do.
    }
    
    @Override
    public void nonBlockingUserNotAllowed(Redirector redirector) throws Exception
    {
        // Nothing to do.
    }

    @Override
    public void blockingUserAllowed(UserIdentity userIdentity, Redirector redirector) throws ProcessingException, IOException
    {
        // Nothing to do.
    }
    
    @Override
    public void nonBlockingUserAllowed(UserIdentity userIdentity, Redirector redirector)
    {
        // Empty method, nothing more to do.
    }

    /**
     * Always returns true.
     * @return true
     */
    public boolean requiresNewWindow()
    {
        return true;
    }

    /**
     * Refresh the access token of the user if needed.
     * Nothing is done when no expiration date is stored in session or when it is not reached yet.
     * Otherwise a new token is silently acquired, for the same scopes as the initial sign-in,
     * and the token data stored in session is updated.
     * @param session the session
     * @throws Exception when an error occurs
     */
    public void refreshTokenIfNeeded(Session session) throws Exception
    {
        // this check is also done by the following MSAL code, but it's way faster with just a simple date check
        Date expDat = (Date) session.getAttribute(__ATTRIBUTE_EXPIRATIONDATE);
        if (expDat != null && new Date().after(expDat))
        {
            ConfidentialClientApplication client = _getClient();
            
            IAccount account = (IAccount) session.getAttribute(__ATTRIBUTE_ACCOUNT);
            String tokenCache = (String) session.getAttribute(__ATTRIBUTE_TOKENCACHE);
            
            // request the scopes used at sign-in, so that implementations overriding getScopes() get a matching token
            SilentParameters parameters = SilentParameters.builder(getScopes(), account).build();
            client.tokenCache().deserialize(tokenCache);
            IAuthenticationResult result = client.acquireTokenSilently(parameters).get();
            
            session.setAttribute(__ATTRIBUTE_EXPIRATIONDATE, result.expiresOnDate());
            session.setAttribute(__ATTRIBUTE_TOKENCACHE, client.tokenCache().serialize());
            session.setAttribute(__ATTRIBUTE_ACCOUNT, result.account());
            
            session.setAttribute(ACCESS_TOKEN_SESSION_ATTRIBUTE, result.accessToken());
        }
    }
}
