001/*
002 *  Copyright 2016 Anyware Services
003 *
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package org.ametys.core.script;
017
import java.io.IOException;
import java.io.InputStream;
import java.io.LineNumberReader;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.io.IOUtils;
import org.apache.excalibur.source.Source;
import org.apache.excalibur.source.SourceResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.ametys.core.datasource.ConnectionHelper;
036
037/**
 * Example of simple use:<br>
 * <code>SQLScriptHelper.createTableIfNotExists(dataSourceId, "QRTZ_JOB_DETAILS", "plugin:core://scripts/%s/quartz.sql", _sourceResolver);</code><br>
 * This tests whether the table QRTZ_JOB_DETAILS exists in the database of the datasource dataSourceId. If it does not, the script plugin:core://scripts/%s/quartz.sql is resolved and executed (where %s is replaced by the database type: 'mysql', 'derby'...).
041 * 
042 * Tools to run SQL scripts.<p>
 * The default separator for isolating statements is the semicolon
 * character: <code>;</code>.<br>
045 * It can be changed by using a comment like the following snippet
046 * for using the string <code>---</code>:<br>
047 * <code>-- _separator_=---<br>
048 * begin<br>
049 * &nbsp;&nbsp;execute immediate 'DROP TABLE MYTABLE';<br>
050 * &nbsp;&nbsp;Exception when others then null;<br>
051 * end;<br>
052 * ---<br>
053 * -- _separator_=;<br>
054 * CREATE TABLE MYTABLE;<br>
055 * ...</code><br>
056 * Note that the command must be placed at the end of the comment.<br><br>
057 * The runner can be configured to ignore SQLExceptions. This can be useful
058 * to execute DROP statements when it's unknown if the tables exist:<br>
059 * <code>--_ignore_exceptions_=on<br>
060 * DROP TABLE MYTABLE;<br>
061 * --_ignore_exceptions_=off</code>
062 */
063public final class SQLScriptHelper
064{
    /** Default separator used for isolating statements. */
    public static final String DEFAULT_SEPARATOR = ";";
    /**
     * Command to ignore sql exceptions.
     * It is looked up in comment lines (starting with <code>--</code> or <code>//</code>);
     * the text following it enables the mode when it is exactly <code>on</code>,
     * and disables it for any other value.
     */
    public static final String IGNORE_EXCEPTIONS_COMMAND = "_ignore_exceptions_=";
    /**
     * Command to change the separator.
     * It is looked up in comment lines (starting with <code>--</code> or <code>//</code>);
     * the trimmed text following it becomes the new separator.
     */
    public static final String CHANGE_SEPARATOR_COMMAND = "_separator_=";
    /**
     * Logger shared by all the methods of this helper.
     * The enclosing class is final, so despite its protected visibility this field
     * cannot be reached through subclassing: it is only visible from this class and its package.
     */
    protected static final Logger __LOGGER = LoggerFactory.getLogger(SQLScriptHelper.class);
    
    /**
     * Private constructor: this class only exposes static members and is not meant to be instantiated.
     */
    private SQLScriptHelper()
    {
        // Nothing to do
    }
078
079    /**
080     * This method will test if a table exists, and if not will execute a script to create it
081     * @param datasourceId The data source id to open a connection to the database
082     * @param tableNameToCheck The name of the table that will be checked
083     * @param location The source location where to find the script to execute to create the table. This string will be format with String.format with the dbType as argument.
084     * @param sourceResolver The source resolver
085     * @return true if the table was created, false otherwise
086     * @throws SQLException If an error occurred while executing SQL script, or while testing table existence
087     * @throws IOException If an error occurred while getting the script file, or if the url is malformed
088     */
089    public static boolean createTableIfNotExists(String datasourceId, String tableNameToCheck, String location, SourceResolver sourceResolver) throws SQLException, IOException
090    {
091        return createTableIfNotExists(datasourceId, tableNameToCheck, location, sourceResolver, null);
092    }
093
094    /**
095     * This method will test if a table exists, and if not will execute a script to create it
096     * @param datasourceId The data source id to open a connection to the database
097     * @param tableNameToCheck The name of the table that will be checked
098     * @param location The source location where to find the script to execute to create the table. This string will be format with String.format with the dbType as argument.
099     * @param sourceResolver The source resolver
100     * @param replace The map of string to replace. Key is the regexp to seek, value is the replacing string.
101     * @return true if the table was created, false otherwise
102     * @throws SQLException If an error occurred while executing SQL script, or while testing table existence
103     * @throws IOException If an error occurred while getting the script file, or if the url is malformed
104     */
105    public static boolean createTableIfNotExists(String datasourceId, String tableNameToCheck, String location, SourceResolver sourceResolver, Map<String, String> replace) throws SQLException, IOException
106    {
107        Connection connection = null;
108        try
109        {
110            connection = ConnectionHelper.getConnection(datasourceId);
111            
112            return createTableIfNotExists(connection, tableNameToCheck, location, sourceResolver, replace);
113        }
114        finally
115        {
116            ConnectionHelper.cleanup(connection);
117        }
118    }
119    
120    /**
121     * This method will test if a table exists, and if not will execute a script to create it
122     * @param connection The database connection to use
123     * @param tableNameToCheck The name of the table that will be checked
124     * @param location The source location where to find the script to execute to create the table. This string will be format with String.format with the dbType as argument.
125     * @param sourceResolver The source resolver
126     * @return true if the table was created, false otherwise
127     * @throws SQLException If an error occurred while executing SQL script, or while testing table existence
128     * @throws IOException If an error occurred while getting the script file, or if the url is malformed
129     */
130    public static boolean createTableIfNotExists(Connection connection, String tableNameToCheck, String location, SourceResolver sourceResolver) throws SQLException, IOException
131    {
132        return createTableIfNotExists(connection, tableNameToCheck, location, sourceResolver, null);
133    }
134    
135    /**
136     * This method will test if a table exists, and if not will execute a script to create it
137     * @param connection The database connection to use
138     * @param tableNameToCheck The name of the table that will be checked
139     * @param location The source location where to find the script to execute to create the table. This string will be format with String.format with the dbType as argument.
140     * @param sourceResolver The source resolver
141     * @param replace The map of string to replace. Key is the regexp to seek, value is the replacing string.
142     * @return true if the table was created, false otherwise
143     * @throws SQLException If an error occurred while executing SQL script, or while testing table existence
144     * @throws IOException If an error occurred while getting the script file, or if the url is malformed
145     */
146    public static boolean createTableIfNotExists(Connection connection, String tableNameToCheck, String location, SourceResolver sourceResolver, Map<String, String> replace) throws SQLException, IOException
147    {
148        if (tableExists(connection, tableNameToCheck))
149        {
150            return false;
151        }
152        
153        String finalLocation = String.format(location, ConnectionHelper.getDatabaseType(connection));
154        
155        Source source = null;
156        try
157        {
158            source = sourceResolver.resolveURI(finalLocation);
159            
160            try (InputStream is = source.getInputStream())
161            {
162                String script = IOUtils.toString(is, "UTF-8");
163
164                if (replace != null)
165                {
166                    for (String replaceKey : replace.keySet())
167                    {
168                        script = script.replaceAll(replaceKey, replace.get(replaceKey));
169                    }
170                }
171                
172                SQLScriptHelper.runScript(connection, script);
173            }
174        }
175        finally
176        {
177            sourceResolver.release(source);
178        }
179        
180        return true;
181    }
182    
183    /**
184     * Checks whether the given table exists in the database.
185     * @param connection The database connection
186     * @param tableName the name of the table
187     * @return true is the table exists
188     * @throws SQLException In an SQL exception occurs
189     */
190    public static boolean tableExists(Connection connection, String tableName) throws SQLException
191    {
192        ResultSet rs = null;
193        
194        DatabaseMetaData metaData = connection.getMetaData();
195        
196        // Test for non escaped table names
197        String name = tableName;
198        if (metaData.storesLowerCaseIdentifiers())
199        {
200            name = tableName.toLowerCase();
201        }
202        else if (metaData.storesUpperCaseIdentifiers())
203        {
204            name = tableName.toUpperCase();
205        }
206        
207        try
208        {
209            rs = metaData.getTables(connection.getCatalog(), connection.getSchema(), name, null);
210            if (rs.next())
211            {
212                return true;
213            }
214        }
215        finally
216        {
217            ConnectionHelper.cleanup(rs);
218        }
219        
220        // Test for escaped table names
221        String quotedName = tableName;
222        if (metaData.storesLowerCaseQuotedIdentifiers())
223        {
224            quotedName = tableName.toLowerCase();
225        }
226        else if (metaData.storesUpperCaseQuotedIdentifiers())
227        {
228            quotedName = tableName.toUpperCase();
229        }
230        
231        if (!quotedName.equals(name))
232        {
233            try
234            {
235                rs = metaData.getTables(connection.getCatalog(), connection.getSchema(), quotedName, null);
236                if (rs.next())
237                {
238                    return true;
239                }
240            }
241            finally
242            {
243                ConnectionHelper.cleanup(rs);
244            }
245        }
246        
247        return false;
248    }
249
250    /**
251     * Run a SQL script using the connection passed in.
252     * @param connection the connection to use for the script
253     * @param script the script data.
254     * @throws IOException if an error occurs while reading the script.
255     * @throws SQLException if an error occurs while executing the script.
256     */
257    public static void runScript(Connection connection, String script) throws IOException, SQLException
258    {
259        ScriptContext scriptContext = new ScriptContext();
260        StringBuilder command = new StringBuilder();
261        
262        try
263        {
264            LineNumberReader lineReader = new LineNumberReader(new StringReader(script));
265            String line = null;
266            while ((line = lineReader.readLine()) != null)
267            {
268                if (__LOGGER.isDebugEnabled())
269                {
270                    __LOGGER.debug(String.format("Reading line: '%s'", line));
271                }
272                
273                boolean processCommand = false;
274                String trimmedLine = line.trim();
275                
276                if (trimmedLine.length() > 0)
277                {
278                    processCommand = processScriptLine(trimmedLine, command, scriptContext);
279                    
280                    if (processCommand)
281                    {
282                        _processCommand(connection, command, lineReader.getLineNumber(), scriptContext);
283                    }
284                }
285            }
286            
287            // If the entire file was processed and the command buffer is not empty, execute the current buffer.
288            if (command.length() > 0)
289            {
290                _processCommand(connection, command, lineReader.getLineNumber(), scriptContext);
291            }
292            
293            if (!connection.getAutoCommit())
294            {
295                connection.commit();
296            }
297        }
298        finally
299        {
300            if (!connection.getAutoCommit())
301            {
302                try
303                {
304                    // Fermer la connexion à la base
305                    connection.rollback();
306                }
307                catch (SQLException s)
308                {
309                    __LOGGER.error("Error while rollbacking connection", s);
310                }
311            }
312        }
313    }
314    
315    /**
316     * Run a SQL script using the connection passed in.
317     * @param connection the connection to use for the script
318     * @param is the input stream containing the script data.
319     * @throws IOException if an error occurs while reading the script.
320     * @throws SQLException if an error occurs while executing the script.
321     */
322    public static void runScript(Connection connection, InputStream is) throws IOException, SQLException
323    {
324        try
325        {
326            String script = IOUtils.toString(is, "UTF-8");
327            runScript(connection, script);
328        }
329        finally
330        {
331            IOUtils.closeQuietly(is);
332        }
333    }
334
335    /**
336     * Process a script line.
337     * @param line the line to process.
338     * @param commandBuffer the command buffer.
339     * @param scriptContext the script execution context.
340     * @return true to immediately process the command (a separator was found), false to process it later.
341     */
342    protected static boolean processScriptLine(String line, StringBuilder commandBuffer, ScriptContext scriptContext)
343    {
344        boolean processCommand = false;
345        
346        if (line.startsWith("//") || line.startsWith("--"))
347        {
348            String currentSeparator = scriptContext.getSeparator();
349            
350            // Search if the separator needs to be changed
351            if (line.contains(CHANGE_SEPARATOR_COMMAND))
352            {
353                // New separator
354                String newSeparator = line.substring(line.indexOf(CHANGE_SEPARATOR_COMMAND)
355                            + CHANGE_SEPARATOR_COMMAND.length()).trim();
356                
357                scriptContext.setSeparator(newSeparator);
358                
359                if (__LOGGER.isDebugEnabled())
360                {
361                    __LOGGER.debug(String.format("Changing separator to: '%s'", newSeparator));
362                }
363            }
364            else if (line.contains(IGNORE_EXCEPTIONS_COMMAND))
365            {
366                String ignoreStr = line.substring(line.indexOf(IGNORE_EXCEPTIONS_COMMAND)
367                            + IGNORE_EXCEPTIONS_COMMAND.length()).trim();
368                
369                boolean ignoreExceptions = "on".equals(ignoreStr);
370                
371                scriptContext.setIgnoreExceptions(ignoreExceptions);
372                
373                if (__LOGGER.isDebugEnabled())
374                {
375                    __LOGGER.debug(String.format("Ignore exceptions: '%s'", ignoreExceptions ? "on" : "off"));
376                }
377            }
378            
379            if (line.contains(currentSeparator))
380            {
381                if (commandBuffer.length() > 0)
382                {
383                    // End of command but do not use current line
384                    processCommand = true;
385                }
386            }
387        }
388        else if (line.endsWith(scriptContext.getSeparator()))
389        {
390            // End of command and use current line
391            processCommand = true;
392            commandBuffer.append(line.substring(0, line.lastIndexOf(scriptContext.getSeparator())));
393        }
394        else
395        {
396            // Append current command to the buffer
397            commandBuffer.append(line);
398            commandBuffer.append(" ");
399        }
400        
401        return processCommand;
402    }
403    
404    private static void _processCommand(Connection connection, StringBuilder command, int lineNumber, ScriptContext scriptContext) throws SQLException
405    {
406        if (__LOGGER.isInfoEnabled())
407        {
408            __LOGGER.info(String.format("Executing SQL command: '%s'", command));
409        }
410        
411        _execute(connection, command.toString(), lineNumber, scriptContext);
412
413        // Clear command
414        command.setLength(0);
415    }
416    
417    private static void _execute(Connection connection, String command, int lineNumber, ScriptContext scriptContext) throws SQLException
418    {
419        Statement statement = null;
420        try
421        {
422            statement = connection.createStatement();
423            statement.execute(command);
424        }
425        catch (SQLException e)
426        {
427            if (!scriptContext.ignoreExceptions())
428            {
429                String message = String.format("Unable to execute SQL: '%s' at line %d", command, lineNumber);
430                __LOGGER.error(message, e);
431                
432                throw new SQLException(message, e);
433            }
434        }
435        finally
436        {
437            ConnectionHelper.cleanup(statement);
438        }
439    }
440    
441    /**
442     * Script execution context.
443     */
444    protected static class ScriptContext
445    {
446        
447        /** The current script execution block separator. */
448        protected String _separator;
449        
450        /** True to ignore sql exceptions. */
451        protected boolean _ignoreExceptions;
452        
453        /**
454         * Default ScriptContext object.
455         */
456        public ScriptContext()
457        {
458            this(DEFAULT_SEPARATOR, false);
459        }
460        
461        /**
462         * Build a ScriptContext object.
463         * @param separator the separator
464         * @param ignoreExceptions true to ignore exceptions.
465         */
466        public ScriptContext(String separator, boolean ignoreExceptions)
467        {
468            this._separator = separator;
469            this._ignoreExceptions = ignoreExceptions;
470        }
471        
472        /**
473         * Get the separator.
474         * @return the separator
475         */
476        public String getSeparator()
477        {
478            return _separator;
479        }
480        
481        /**
482         * Set the separator.
483         * @param separator the separator to set
484         */
485        public void setSeparator(String separator)
486        {
487            this._separator = separator;
488        }
489        
490        /**
491         * Get the ignoreExceptions.
492         * @return the ignoreExceptions
493         */
494        public boolean ignoreExceptions()
495        {
496            return _ignoreExceptions;
497        }
498        
499        /**
500         * Set the ignoreExceptions.
501         * @param ignoreExceptions the ignoreExceptions to set
502         */
503        public void setIgnoreExceptions(boolean ignoreExceptions)
504        {
505            this._ignoreExceptions = ignoreExceptions;
506        }
507        
508    }
509    
510}