Files @ 7f3ee64c982a
Branch filter:

Location: hc-utils/utilities/csv_load.py

andy47
Coping with different database paramstyles in csv load
#!/usr/bin/python
"""
 Module  : csv_load.py
 License : BSD License (see LICENSE.txt)

This module loads the contents of a csv file into a database table.

Note that there are a few conditions, not all of which are adequately tested in
the code;

* There must be one field in each row of the file for each column in the table
* The order of the fields in the file must be the same as when you execute
  SELECT * FROM table. If in doubt use the --check flag to see what is
  expected
* If the csv file has a header row you must use the --skip flag

Arguments;
    1st - db connection string (See dburi.py for format)
    2nd - filename to load from
    3rd - table name

To get round the problem of large tables, I've utilised the ResultIter function
from the Python Cookbook
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/137270
Many thanks to Christopher Prinos for the marvelous code.

Release 1.0.1 has *only* been tested on Python 3.4 and above

TODO:: - Write unit tests
TODO:: - Make the module more database agnostic
"""

# Module metadata: version tuple, release date tuple (y, m, d) and author
__version__ = (1, 0, 1)
__date__ = (2017, 1, 12)
__author__ = "Andy Todd <andy47@halfcooked.com>"

import argparse
import csv
import itertools
import sys

from simple_log import get_log

from utilities.dburi import get_connection

debug = False


class InvalidFile(Exception):
    """Raised when a csv file doesn't match the target table, e.g. its rows
    have a different number of fields than the table has columns."""
    pass


def load(connection, paramstyle, file_name, table_name, skip, check, log=None):
    """Load the contents of the csv file file_name into table_name.

    If ``skip`` is True the first row of the file is treated as a heading row
    and its names are used as the INSERT column list; otherwise every row
    (including the first) is data and the table's own column names are used.

    @param connection: Valid (open) database connection
    @type connection: DB-API 2.0 connection object
    @param paramstyle: Which bind parameter style to use when building the
        INSERT statement; 'qmark' and 'format' are positional, anything else
        falls back to named (:col) placeholders
    @type paramstyle: Valid DB-API 2.0 paramstyle, one of ['qmark', 'numeric', 'named', 'format', 'pyformat']
    @param file_name: Name of the csv file to load
    @type file_name: String
    @param table_name: Name of the table to load with the data from file_name
    @type table_name: String
    @param skip: Is the first row of the csv file a heading row (and should
        therefore not be loaded as data)?
    @type skip: Boolean
    @param check: Just perform validation checks and don't load the data
    @type check: Boolean
    @param log: Log object to write info, errors and warnings to
    @return: None
    @raise InvalidFile: if the file's rows don't match the table's columns
    """
    if not log:
        log = get_log()
    cursor = connection.cursor()
    # Step 1, determine the columns in table_name. WHERE 1 = 0 gives us
    # cursor.description without fetching any of the table's rows (the old
    # SELECT * pulled the whole table just to read the column names).
    # NOTE(review): table_name is interpolated straight into the SQL because
    # identifiers can't be bound as parameters - it must come from a trusted
    # source.
    cursor.execute("SELECT * FROM %s WHERE 1 = 0" % table_name)
    columns = [col[0] for col in cursor.description]
    with open(file_name, 'r') as csv_file:
        csv_reader = csv.reader(csv_file, dialect='excel')
        first_row = next(csv_reader)
        log.info("Opened file %s and read first row" % file_name)
        if len(first_row) != len(columns):
            raise InvalidFile("Table {} has different number of columns to {}".
                              format(table_name, file_name))
        # Column names for the INSERT: the heading row when there is one,
        # otherwise the names from the table itself (previously the first
        # *data* row was misused as the column list when skip was False).
        column_names = first_row if skip else columns
        # Build our insert statement
        stmt = "INSERT INTO " + table_name + " (" + ",".join(column_names) + ")"
        if paramstyle == 'qmark':
            stmt += " VALUES (" + ",".join("?" * len(columns)) + ")"
        elif paramstyle == 'format':
            stmt += " VALUES (" + ",".join(["%s"] * len(columns)) + ")"
        else:
            # Named placeholders; most drivers also accept a plain sequence
            # here, binding by position
            stmt += " VALUES ( :" + ", :".join(column_names) + ")"
        log.debug(stmt)
        record_count = 0
        # When there is no heading row the first row we already read is data
        # and must be loaded too (the old code silently dropped one row on
        # every run - the row after the header with skip, the first row
        # without it).
        data_rows = csv_reader if skip else itertools.chain([first_row],
                                                            csv_reader)
        for row in data_rows:
            if not check:
                log.debug(row)
                cursor.execute(stmt, row)
            record_count += 1
            if record_count % 100 == 0:
                log.info("Loaded {} records".format(record_count))
        cursor.close()
        connection.commit()
        if check:
            log.info("Finished reading from file {} counted {} rows".
                     format(file_name, record_count))
        else:
            log.info("Finished reading from file {} inserted {} rows".
                     format(file_name, record_count))


def load_pyformat(connection, file_name, table_name, skip, check, log=None):
    """Load the csv file file_name into table_name using 'pyformat' binds.

    TODO:: fold this into the main load function, which should determine the
    proper thing to do based on the driver module's paramstyle.

    @param connection: Valid (open) database connection
    @type connection: DB-API 2.0 connection object whose driver uses the
        'pyformat' paramstyle (e.g. psycopg2, MySQLdb)
    @param file_name: Name of the csv file to load
    @type file_name: String
    @param table_name: Name of the table to load with the data from file_name
    @type table_name: String
    @param skip: Is the first row of the csv file a heading row (and should
        therefore not be loaded as data)?
    @type skip: Boolean
    @param check: Just perform validation checks and don't load the data
    @type check: Boolean
    @param log: Log object to write info, errors and warnings to
    @return: None
    @raise InvalidFile: if the file's rows don't match the table's columns
    """
    if not log:
        log = get_log()
    cursor = connection.cursor()
    # Step 1, determine the columns in table_name. WHERE 1 = 0 populates
    # cursor.description without fetching any rows (and sidesteps the
    # empty-table concern the old comment mentioned).
    # NOTE(review): table_name is interpolated directly into the SQL because
    # identifiers can't be bound - it must come from a trusted source.
    cursor.execute("SELECT * FROM %s WHERE 1 = 0" % table_name)
    columns = [col[0] for col in cursor.description]
    with open(file_name, 'r') as csv_file:
        csv_reader = csv.reader(csv_file)
        first_row = next(csv_reader)
        if len(first_row) != len(columns):
            raise InvalidFile("Table {} has different number of columns to {}".
                              format(table_name, file_name))
        # Formulate our insert statement using named pyformat placeholders
        stmt = "INSERT INTO " + table_name + " (" + ",".join(columns) + ")"
        stmt += " VALUES ( %(" + ")s, %(".join(columns) + ")s )"
        log.debug(stmt)
        record_count = 0
        # Without a heading row the first row already read is data and must
        # be loaded too (the old code dropped one data row on every run).
        data_rows = csv_reader if skip else itertools.chain([first_row],
                                                            csv_reader)
        for row in data_rows:
            if not check:
                log.debug(row)
                # Named %(col)s placeholders must be bound from a mapping;
                # the old code passed the raw list, which pyformat drivers
                # such as psycopg2 reject
                cursor.execute(stmt, dict(zip(columns, row)))
            record_count += 1
            if record_count % 100 == 0:
                log.info("Loaded %d records" % record_count)
        cursor.close()
        connection.commit()
        if check:
            log.info("Finished reading from file {} counted {} rows".
                     format(file_name, record_count))
        else:
            log.info("Finished reading from file {} inserted {} rows".
                     format(file_name, record_count))


def main(argv=None):
    """Parse the command line and load the csv file into the named table.

    Main function modelled on Guido's guidelines.

    @param argv: Command line arguments; defaults to sys.argv
    @type argv: List of strings or None
    @return: Status indicator, 0 on success
    @rtype: Integer
    """
    if argv is None:
        argv = sys.argv
    # Set up our command line parser for optional and positional arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", action="store_true", default=False)
    parser.add_argument("-s", "--skip", action="store_true", default=False)
    parser.add_argument("-c", "--check", action="store_true", default=False)
    # Which DB-API bind parameter style the target driver expects; 'qmark'
    # is the safest general default (sqlite3, most ODBC drivers)
    parser.add_argument("-p", "--paramstyle", default="qmark",
                        choices=["qmark", "numeric", "named", "format",
                                 "pyformat"],
                        help="DB-API paramstyle of the target driver")
    parser.add_argument('conn_string', help='dburi connection string')
    parser.add_argument('input_file', help='CSV input file name')
    parser.add_argument('table_name', help='Name of table to load into')
    # Parse the command line
    args = parser.parse_args(argv[1:])
    # Set up logging (the logger was previously misnamed 'csv_dump')
    log = get_log('csv_load')
    if args.debug:
        log.setLevel('DEBUG')
    # What are we working with here exactly?
    log.debug("connection string {}".format(args.conn_string))
    log.debug("file name {}".format(args.input_file))
    log.debug("table name {}".format(args.table_name))
    connection = get_connection(args.conn_string)
    # load() takes paramstyle as its second positional argument; the old call
    # omitted it, shifting every argument one slot to the left
    load(connection, args.paramstyle, args.input_file, args.table_name,
         args.skip, args.check, log=log)
    return 0

# Script entry point: exit with main()'s return value as the process status
if __name__ == "__main__":
    sys.exit(main())