# 将 MySQL 数据库的建表语句转换为可在 Greenplum 数据库中执行的脚本

# 根据输入的表名获取对应表的建表信息:
# 首先输入需要获取的 MySQL 或 SQL Server 数据库表名称,到对应的数据库中取出建表语句,
# 然后对建表语句进行相应的修改,生成可以在 PostgreSQL/Greenplum 中执行的语句

# 连接 MySQL 数据库
import pymysql
import sys
import re


class MysqlDB:
    """Read a MySQL table's column layout and emit Greenplum ``CREATE TABLE`` DDL.

    Connection settings are plain instance attributes; assign them after
    construction and before calling any method that connects.
    """

    # Schema-qualified prefix prepended to the source table name.
    TARGET_PREFIX = 'dw_stg.stg_cus_dim_'
    # ETL bookkeeping columns appended after the source columns.
    AUDIT_COLUMNS = (
        'ord_loc varchar(10)',
        'modify_date_etl timestamp default now()',
        'load_dt timestamp default now()',
    )

    def __init__(self):
        self.hostname = ''
        self.port = 3306
        self.username = ''
        self.password = ''
        self.database = ''

    def connectmysql(self):
        """Open a MySQL connection from the instance's settings.

        Returns:
            The pymysql connection, or ``None`` when connecting fails. The
            failure is printed rather than raised, so callers must check
            for ``None``.
        """
        try:
            # `password`/`database` are the canonical keyword names; the
            # `passwd`/`db` spellings are deprecated aliases.
            return pymysql.connect(
                host=self.hostname,
                port=self.port,
                user=self.username,
                password=self.password,
                database=self.database,
            )
        except Exception as exc:
            print("连接异常", exc)
            return None

    def insertTable(self):
        """Execute the sample query on ``cn_customer.bank`` and commit.

        Despite its name, the statement that runs is a SELECT. Any error is
        printed and the transaction rolled back; the connection is always
        closed. Does nothing when the connection cannot be opened.
        """
        condb = self.connectmysql()
        if condb is None:
            return
        try:
            # The cursor is closed when the `with` block exits.
            with condb.cursor() as cursor:
                cursor.execute("select * from cn_customer.bank")
            condb.commit()
        except Exception as exc:
            print("发生异常", exc)
            condb.rollback()
        finally:
            condb.close()

    def _convert_type(self, data_type):
        """Map a MySQL column type string to a Greenplum type.

        Args:
            data_type: Type text as reported by MySQL, e.g. ``'int(11)'``
                or ``'bigint(20) unsigned'``.

        Returns:
            The Greenplum type name, or ``data_type`` unchanged when no
            mapping applies.
        """
        normalized = data_type.lower().strip()
        # One-bit columns are checked first so the generic tinyint rule
        # below does not swallow them.
        if normalized in ('bit(1)', 'tinyint(1)', 'tinyint(1) unsigned'):
            return 'boolean'

        # Match on the leading type keyword rather than a substring, so
        # e.g. 'bigint' or 'point' is never mistaken for plain 'int'.
        keyword = re.match(r'[a-z]+', normalized)
        base = keyword.group(0) if keyword else ''
        unsigned = 'unsigned' in normalized

        # Unsigned types are widened one step so the full range still fits.
        if base in ('tinyint', 'year'):
            return 'smallint'  # Greenplum has no tinyint
        if base == 'smallint':
            return 'integer' if unsigned else 'smallint'
        if base == 'mediumint':
            return 'integer'
        if base in ('int', 'integer'):
            return 'bigint' if unsigned else 'int4'
        if base == 'bigint':
            return 'numeric' if unsigned else 'bigint'
        if base == 'float':
            return 'float'
        if base == 'decimal':
            return 'decimal'
        if base == 'double':
            return 'double precision'
        if base == 'datetime':
            # Keep an optional fractional-seconds precision: datetime(6)
            # becomes timestamp(6).
            return 'timestamp' + normalized[len('datetime'):]
        return data_type

    def _build_field(self, table_name, row):
        """Turn one ``EXPLAIN`` result row into a column description dict.

        Args:
            table_name: Source table name (stored lower-cased).
            row: Sequence whose first four items are the column name, type,
                Null flag (``'YES'``/``'NO'``) and Key flag.
        """
        return {
            'column_name': row[0].lower(),
            'table_name': table_name.lower(),
            'type': self._convert_type(row[1]).lower(),
            # The Null flag is 'NO' for columns that reject NULL.
            'null': 'not null' if row[2] == 'NO' else '',
            'primary_key': row[3] == 'PRI',
        }

    def load_columns(self, table_name):
        """Fetch the column names and types of ``table_name`` and print its DDL.

        Args:
            table_name: Name of the MySQL table to describe.

        Returns:
            The generated ``CREATE TABLE`` statement, or ``None`` when the
            connection cannot be opened.
        """
        condb = self.connectmysql()
        if condb is None:
            return None
        # Identifiers cannot be bound as query parameters, so double any
        # backtick to keep the name inside its quoting.
        quoted_name = table_name.replace('`', '``')
        try:
            with condb.cursor() as cursor:
                cursor.execute('EXPLAIN `%s`' % quoted_name)
                table_info = cursor.fetchall()
        finally:
            condb.close()
        fields = [self._build_field(table_name, row) for row in table_info]
        return self.postgres_create(fields)

    def postgres_create(self, fields):
        """Build, print and return the Greenplum ``CREATE TABLE`` statement.

        Args:
            fields: Non-empty list of column dicts with the keys
                ``column_name``, ``table_name``, ``type``, ``null`` and
                ``primary_key``.

        Returns:
            The DDL text. With a primary key the table is distributed by
            the key columns; without one the primary-key clause is omitted
            and the table is distributed randomly.

        Raises:
            ValueError: If ``fields`` is empty.
        """
        if not fields:
            raise ValueError('fields must contain at least one column')

        table_name = self.TARGET_PREFIX + fields[0]['table_name']
        definitions = []
        primary_key = []
        for field in fields:
            if field['primary_key']:
                primary_key.append(field['column_name'])
            parts = [field['column_name'], field['type']]
            if field['null']:
                parts.append(field['null'])
            definitions.append(' '.join(parts))
        definitions.extend(self.AUDIT_COLUMNS)

        if primary_key:
            key_columns = ','.join(primary_key)
            definitions.append('primary key(%s)' % key_columns)
            distribution = 'distributed by (%s)' % key_columns
        else:
            # An empty key list would make both clauses invalid SQL.
            distribution = 'distributed randomly'

        create_sql = 'create table %s (\n%s\n)\n%s' % (
            table_name, ',\n'.join(definitions), distribution)
        print(create_sql)
        return create_sql


def main():
    """Print the Greenplum DDL for the ``bank`` table.

    The ``MysqlDB`` instance is used with its default (blank) connection
    settings; set host, user, password and database before real use.
    """
    mysqldb = MysqlDB()
    mysqldb.load_columns('bank')


# Run only when executed as a script, so importing this module does not
# open a database connection as a side effect.
if __name__ == '__main__':
    main()

# 猜你喜欢
#
# 转载自 blog.csdn.net/qq_22994783/article/details/82879068