Source Code Analysis of the Read Database Operator in RapidMiner 5

1 The Read Database operator

OperatorCore.xml registers the Read Database operator (key read_database) with the class DatabaseDataReader:

<operator>
	<key>read_database</key>
	<class>com.rapidminer.operator.io.DatabaseDataReader</class>
	<replaces>DatabaseExampleSource</replaces>
</operator>

1.1 DatabaseDataReader class diagram

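The class diagram of DatabaseDataReader is shown below. In short, DatabaseDataReader implements ConnectionProvider and extends AbstractExampleSource, which in turn extends AbstractReader.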

1.2 ConnectionProvider

DatabaseDataReader implements the ConnectionProvider interface.

The interface declares a single method, which returns a ConnectionEntry holding the connection information.

public interface ConnectionProvider {
	public ConnectionEntry getConnectionEntry();
}

1.3 AbstractExampleSource and AbstractReader

DatabaseDataReader extends AbstractExampleSource, which in turn extends AbstractReader. Leaving AbstractReader's other methods aside, its doWork() method is:

	@Override
	public void doWork() throws OperatorException {
		final T result = read();
		addAnnotations(result);
		outputPort.deliver(result);
	}

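AbstractReader does not implement read() itself. AbstractExampleSource implements read() by delegating to createExampleSet(), but it declares that method abstract, so the actual reading is left to subclasses.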

	/** Creates (or reads) the ExampleSet that will be returned by {@link #apply()}. */
	public abstract ExampleSet createExampleSet() throws OperatorException;

	@Override
	public ExampleSet read() throws OperatorException {
		return createExampleSet();
	}
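The chain doWork() → read() → createExampleSet() is a template-method arrangement: the base classes fix the order of steps, and a subclass supplies only the last one. The minimal sketch below uses hypothetical, heavily simplified stand-ins. SketchReader, SketchExampleSource and SketchDatabaseReader are not RapidMiner classes, and a plain field replaces the output port. The classes record the order in which the steps run.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for AbstractReader: fixes the doWork() steps.
abstract class SketchReader<T> {
    final List<String> trace = new ArrayList<>();
    T delivered; // replaces outputPort.deliver(result) in this sketch

    public final void doWork() {
        T result = read();          // step 1: read
        trace.add("addAnnotations"); // step 2: annotate (no-op here)
        delivered = result;          // step 3: deliver
        trace.add("deliver");
    }

    public abstract T read();
}

// Hypothetical stand-in for AbstractExampleSource: read() delegates to a factory method.
abstract class SketchExampleSource extends SketchReader<List<String>> {
    @Override
    public List<String> read() {
        trace.add("read");
        return createExampleSet();
    }

    public abstract List<String> createExampleSet();
}

// Hypothetical stand-in for DatabaseDataReader: only the last step is implemented.
class SketchDatabaseReader extends SketchExampleSource {
    @Override
    public List<String> createExampleSet() {
        trace.add("createExampleSet");
        return List.of("row1", "row2");
    }
}
```

Calling doWork() on a SketchDatabaseReader runs read, createExampleSet, addAnnotations and deliver, in that order.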

1.4 createExampleSet() in DatabaseDataReader

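DatabaseDataReader overrides read(), but the override simply calls the superclass's read(). In a finally block it then closes the connection held by its DatabaseHandler, logging a warning if closing fails. The actual reading still happens in createExampleSet(), so DatabaseDataReader must implement that method.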

	private DatabaseHandler databaseHandler;

	@Override
	public ExampleSet read() throws OperatorException {
		try {
			ExampleSet result = super.read();
			return result;
		} finally {
			if (databaseHandler != null && databaseHandler.getConnection() != null) {
				try {
					databaseHandler.getConnection().close();
				} catch (SQLException e) {
					getLogger().log(Level.WARNING, "Error closing database connection: " + e, e);
				}
			}
		}
	}
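In this read() override, the connection is opened inside createExampleSet() (through getResultSet(), shown later) but closed one level up, in a finally block. That way it is released even if reading fails. The minimal sketch below shows that pattern. FakeConnection is a hypothetical stand-in for a real JDBC connection that only records whether close() ran.

```java
import java.util.function.Supplier;

// Hypothetical stand-in for a JDBC connection; it only records close().
class FakeConnection implements AutoCloseable {
    boolean closed = false;

    @Override
    public void close() {
        closed = true;
    }
}

class CloseInFinallySketch {
    // Mirrors DatabaseDataReader.read(): run the read, then always close the
    // connection, whether the read returned normally or threw.
    static <T> T readAndClose(FakeConnection connection, Supplier<T> read) {
        try {
            return read.get();
        } finally {
            if (connection != null) {
                connection.close();
            }
        }
    }
}
```

With real JDBC objects, try-with-resources gives the same guarantee, since Connection, Statement and ResultSet are all AutoCloseable as of Java 7. The RapidMiner 5 code shown here uses explicit finally blocks instead.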

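createExampleSet() works in three steps:

1. getResultSet() fetches the data from the database.
2. getAttributes() derives the attributes from the result set's columns.
3. The data and attributes are loaded into a MemoryExampleTable, which is returned as an ExampleSet.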

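As its name suggests, MemoryExampleTable keeps all data in memory. Is there room for improvement here? For a very large table, one could avoid fetching every row at once and instead fetch only rows x to y through paging parameters.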

	@Override
	public ExampleSet createExampleSet() throws OperatorException {
		ResultSet resultSet = getResultSet();
		MemoryExampleTable table;
		try {
			List<Attribute> attributes = getAttributes(resultSet);
			table = createExampleTable(resultSet, attributes, getParameterAsInt(ExampleSource.PARAMETER_DATAMANAGEMENT), getLogger());
		} catch (SQLException e) {
			throw new UserError(this, e, 304, e.getMessage());
		} finally {
			try {
				resultSet.close();
			} catch (SQLException e) {
				getLogger().log(Level.WARNING, "DB error closing result set: " + e, e);
			}
		}
		return table.createExampleSet();
	}
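The paging idea raised before createExampleSet() could start from a helper that splits one query into page queries. The sketch below is hypothetical and not part of RapidMiner. It assumes a SQL dialect with LIMIT ... OFFSET ... (such as MySQL, PostgreSQL or SQLite) and that the total row count is already known.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of RapidMiner: splits a query into pages.
// Assumes LIMIT/OFFSET syntax; Oracle or SQL Server need different syntax.
class PagedQuerySketch {
    static List<String> pagedQueries(String baseQuery, long totalRows, int pageSize) {
        if (pageSize <= 0) {
            throw new IllegalArgumentException("pageSize must be positive");
        }
        List<String> pages = new ArrayList<>();
        for (long offset = 0; offset < totalRows; offset += pageSize) {
            // Wrap the original query so LIMIT/OFFSET apply to its complete result.
            pages.add("SELECT * FROM (" + baseQuery + ") AS t LIMIT " + pageSize + " OFFSET " + offset);
        }
        return pages;
    }
}
```

The base query needs a deterministic ORDER BY. Without one, pages may overlap or skip rows.

A lighter option within standard JDBC is Statement.setFetchSize(), which hints the driver to transfer rows in batches. The rows would still end up in the in-memory MemoryExampleTable, though.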

1.4.1 getResultSet()

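First, getResultSet(). It does three things:

1. It obtains a connected DatabaseHandler from DatabaseHandler.getConnectedDatabaseHandler().
2. It builds the SQL statement with getQuery().
3. It executes the statement.

If no query can be determined, it throws UserError 202.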

	protected ResultSet getResultSet() throws OperatorException {
		try {
			databaseHandler = DatabaseHandler.getConnectedDatabaseHandler(this);
			String query = getQuery(databaseHandler.getStatementCreator());
			if (query == null) {
				throw new UserError(this, 202, new Object[] { "query", "query_file", "table_name" });
			}
			return databaseHandler.executeStatement(query, true, this, getLogger());
		} catch (SQLException sqle) {
			throw new UserError(this, sqle, 304, sqle.getMessage());
		}
	}

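The operator's parameters show that Read Database offers three ways to define the connection: predefined, url and jndi. getConnectedDatabaseHandler() handles each of them, extracts the connection information, and returns it wrapped in a DatabaseHandler.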

	public static DatabaseHandler getConnectedDatabaseHandler(Operator operator) throws OperatorException, SQLException {
		switch (operator.getParameterAsInt(PARAMETER_DEFINE_CONNECTION)) {
			case CONNECTION_MODE_PREDEFINED:
				String repositoryName = null;
				if (operator.getProcess() != null) {
					RepositoryLocation repositoryLocation = operator.getProcess().getRepositoryLocation();
					if (repositoryLocation != null) {
						repositoryName = repositoryLocation.getRepositoryName();
					}
				}
				ConnectionEntry entry = DatabaseConnectionService.getConnectionEntry(operator.getParameterAsString(PARAMETER_CONNECTION), repositoryName);
				if (entry == null) {
					throw new UserError(operator, 318, operator.getParameterAsString(PARAMETER_CONNECTION));
				}
				return getConnectedDatabaseHandler(entry); //.getURL(), entry.getUser(), new String(entry.getPassword()));
			case DatabaseHandler.CONNECTION_MODE_JNDI:
				final String jndiName = operator.getParameterAsString(PARAMETER_JNDI_NAME);
				try {
					InitialContext ctx;
					ctx = new InitialContext();
					DataSource source = (DataSource) ctx.lookup(jndiName);
					return getHandler(source.getConnection());
				} catch (NamingException e) {
					throw new OperatorException("Failed to lookup '" + jndiName + "': " + e, e);
				}
			case DatabaseHandler.CONNECTION_MODE_URL:
			default:
				return getConnectedDatabaseHandler(operator.getParameterAsString(PARAMETER_DATABASE_URL),
						operator.getParameterAsString(PARAMETER_USERNAME),
						operator.getParameterAsString(PARAMETER_PASSWORD));
		}
	}
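getConnectedDatabaseHandler() is essentially a dispatch on the connection mode. The sketch below models only that dispatch. Two maps stand in for DatabaseConnectionService and the JNDI InitialContext. ConnectionModeSketch and its methods are hypothetical, and they resolve to a JDBC URL string rather than an open connection.

```java
import java.util.Map;

// Hypothetical model of the three connection modes; no RapidMiner types involved.
class ConnectionModeSketch {
    enum Mode { PREDEFINED, URL, JNDI }

    private final Map<String, String> predefined; // connection name -> JDBC URL
    private final Map<String, String> jndi;       // JNDI name -> JDBC URL

    ConnectionModeSketch(Map<String, String> predefined, Map<String, String> jndi) {
        this.predefined = predefined;
        this.jndi = jndi;
    }

    // Returns the JDBC URL that the given mode and parameter value resolve to.
    String resolveUrl(Mode mode, String value) {
        switch (mode) {
            case PREDEFINED: {
                String url = predefined.get(value);
                if (url == null) {
                    // The real operator throws UserError 318 here.
                    throw new IllegalArgumentException("Unknown connection: " + value);
                }
                return url;
            }
            case JNDI: {
                String url = jndi.get(value);
                if (url == null) {
                    // The real operator wraps the NamingException in an OperatorException.
                    throw new IllegalArgumentException("Failed to lookup '" + value + "'");
                }
                return url;
            }
            case URL:
            default:
                // URL mode: the parameter already is the JDBC URL.
                return value;
        }
    }
}
```

As in the original, URL is also the default branch, so an unexpected mode value falls back to treating the parameter as a URL.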

Next, getQuery() likewise distinguishes three cases (a query string, a query file, or a table name) and returns the resulting SQL statement.

	private String getQuery(StatementCreator sc) throws OperatorException {
		switch (getParameterAsInt(DatabaseHandler.PARAMETER_DEFINE_QUERY)) {
			case DatabaseHandler.QUERY_QUERY: {
				String query = getParameterAsString(DatabaseHandler.PARAMETER_QUERY);
				if (query != null) {
					query = query.trim();
				}
				return query;
			}
			case DatabaseHandler.QUERY_FILE: {
				File queryFile = getParameterAsFile(DatabaseHandler.PARAMETER_QUERY_FILE);
				if (queryFile != null) {
					String query = null;
					try {
						query = Tools.readTextFile(queryFile);
					} catch (IOException ioe) {
						throw new UserError(this, ioe, 302, new Object[] { queryFile, ioe.getMessage() });
					}
					if (query == null || query.trim().length() == 0) {
						throw new UserError(this, 205, queryFile);
					}
					return query;
				}
			}
			case DatabaseHandler.QUERY_TABLE:
				TableName tableName = DatabaseHandler.getSelectedTableName(this);
				//final String tableName = getParameterAsString(DatabaseHandler.PARAMETER_TABLE_NAME);
				return "SELECT * FROM " + sc.makeIdentifier(tableName);
		}
		return null;
	}
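Two details of getQuery() are worth noting:

- In table mode, it builds SELECT * FROM with an identifier quoted by the StatementCreator.
- The QUERY_FILE branch has no return after its if block. If getParameterAsFile() returns null, control falls through into the QUERY_TABLE case. The source does not say whether this is intended.

The hypothetical sketch below returns null explicitly in that case, which would let getResultSet() report UserError 202. Its ANSI double-quote rule is an assumption, since the real makeIdentifier() chooses quoting per database.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical sketch of query resolution; not the RapidMiner implementation.
class QuerySourceSketch {
    enum QueryMode { QUERY, FILE, TABLE }

    // Assumed quoting rule: ANSI double quotes, embedded quotes doubled.
    static String quoteIdentifier(String name) {
        return "\"" + name.replace("\"", "\"\"") + "\"";
    }

    static String getQuery(QueryMode mode, String query, Path queryFile, String tableName) {
        switch (mode) {
            case QUERY:
                return query == null ? null : query.trim();
            case FILE: {
                if (queryFile == null) {
                    return null; // explicit, instead of falling through to TABLE
                }
                String text;
                try {
                    text = Files.readString(queryFile);
                } catch (IOException e) {
                    // The real operator reports UserError 302 here.
                    throw new UncheckedIOException(e);
                }
                if (text.trim().isEmpty()) {
                    // The real operator reports UserError 205 here.
                    throw new IllegalArgumentException("Empty query file: " + queryFile);
                }
                return text;
            }
            case TABLE:
                return "SELECT * FROM " + quoteIdentifier(tableName);
            default:
                return null;
        }
    }
}
```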

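Finally, the DatabaseHandler executes the statement:

- If DatabaseHandler.PARAMETER_PREPARE_STATEMENT is set, it uses a PreparedStatement and binds each parameter according to its declared type (VARCHAR, REAL, LONG or INTEGER).
- Otherwise, it runs the SQL through a plain Statement.

For queries the ResultSet is returned still open. Only the statements of non-queries are closed here.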

	public ResultSet executeStatement(String sql, boolean isQuery, Operator parameterHandler, Logger logger) throws SQLException, OperatorException {
		ResultSet resultSet = null;
		Statement statement;
		if (parameterHandler.getParameterAsBoolean(DatabaseHandler.PARAMETER_PREPARE_STATEMENT)) {
			PreparedStatement prepared = getConnection().prepareStatement(sql);
			String[] parameters = ParameterTypeEnumeration.transformString2Enumeration(parameterHandler.getParameterAsString(DatabaseHandler.PARAMETER_PARAMETERS));
			for (int i = 0; i < parameters.length; i++) {
				String[] argDescription = ParameterTypeTupel.transformString2Tupel(parameters[i]);
				final String sqlType = argDescription[0];
				final String replacementValue = argDescription[1];
				if ("VARCHAR".equals(sqlType)) {
					prepared.setString(i + 1, replacementValue);
				} else if ("REAL".equals(sqlType)) {
					try {
						prepared.setDouble(i + 1, Double.parseDouble(replacementValue));
					} catch (NumberFormatException e) {
						prepared.close();
						throw new UserError(parameterHandler, 158, replacementValue, sqlType);
					}
				} else if ("LONG".equals(sqlType)) {
					try {
						prepared.setLong(i + 1, Long.parseLong(replacementValue));
					} catch (NumberFormatException e) {
						prepared.close();
						throw new UserError(parameterHandler, 158, replacementValue, sqlType);
					}
				} else if ("INTEGER".equals(sqlType)) {
					try {
						prepared.setInt(i + 1, Integer.parseInt(replacementValue));
					} catch (NumberFormatException e) {
						prepared.close();
						throw new UserError(parameterHandler, 158, replacementValue, sqlType);
					}
				} else {
					prepared.close();
					throw new OperatorException("Illegal data type: " + sqlType);
				}
			}
			if (isQuery) {
				resultSet = prepared.executeQuery();
			} else {
				prepared.execute();
			}
			statement = prepared;
		} else {
			logger.info("Executing query: '" + sql + "'");
			statement = createStatement(false);
			if (isQuery) {
				resultSet = statement.executeQuery(sql);
			} else {
				statement.execute(sql);
			}
		}
		logger.fine("Query executed.");
		if (!isQuery) {
			statement.close();
		}
		return resultSet;
	}
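In prepared-statement mode, each parameter is a (type, value) pair that is parsed before binding. JDBC parameter indexes are 1-based, hence setX(i + 1, ...). The conversion step alone can be sketched as follows. ParameterBindingSketch is hypothetical: it returns the boxed Java value instead of calling a PreparedStatement, and it uses IllegalArgumentException in place of UserError 158 and OperatorException.

```java
// Hypothetical helper mirroring the per-parameter type handling in executeStatement().
class ParameterBindingSketch {
    // Returns the value that setString/setDouble/setLong/setInt would receive.
    static Object convert(String sqlType, String value) {
        try {
            switch (sqlType) {
                case "VARCHAR":
                    return value;
                case "REAL":
                    return Double.parseDouble(value);
                case "LONG":
                    return Long.parseLong(value);
                case "INTEGER":
                    return Integer.parseInt(value);
                default:
                    // Corresponds to the "Illegal data type" OperatorException.
                    throw new IllegalArgumentException("Illegal data type: " + sqlType);
            }
        } catch (NumberFormatException e) {
            // Corresponds to UserError 158 in the original code.
            throw new IllegalArgumentException("Cannot convert '" + value + "' to " + sqlType, e);
        }
    }
}
```

Unlike the original, this sketch holds no statement, so it does not need the prepared.close() calls that precede each error in executeStatement().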

1.4.2 getAttributes()

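The second method, getAttributes(), is straightforward. For each column it:

- reads the column's label from the ResultSetMetaData,
- renames duplicate labels by appending a numeric suffix,
- maps the JDBC column type to a RapidMiner value type,
- stores the original SQL type name in the attribute's "sql_type" annotation.

createExampleSet() passes the ResultSet itself, so the call presumably goes through an overload that hands resultSet.getMetaData() to the method below.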

	private static List<Attribute> getAttributes(ResultSetMetaData metaData) throws SQLException {
		List<Attribute> result = new LinkedList<Attribute>();

		if (metaData != null) {
			// A map mapping original column names to a counter specifying how often
			// they were chosen
			Map<String, Integer> duplicateNameMap = new HashMap<String, Integer>();

			for (int columnIndex = 1; columnIndex <= metaData.getColumnCount(); columnIndex++) {

				// column name from DB
				String dbColumnName = metaData.getColumnLabel(columnIndex);

				// name that will be used in example set
				String columnName = dbColumnName;

				// check original name first
				Integer duplicateCount = duplicateNameMap.get(dbColumnName);
				boolean isUnique = duplicateCount == null;
				if (isUnique) {
					// name is unique
					duplicateNameMap.put(columnName, new Integer(1));
				} else {
					// name already present, iterate until unique
					while (!isUnique) {
						// increment duplicate counter
						duplicateCount = new Integer(duplicateCount.intValue() + 1);

						// create new name proposal
						columnName = dbColumnName + "_" + (duplicateCount - 1);  // -1 because of compatibility

						// check if new name is already taken
						isUnique = duplicateNameMap.get(columnName) == null;
					}

					// save new duplicate count for old db column name
					duplicateNameMap.put(dbColumnName, duplicateCount);
				}

				int attributeType = DatabaseHandler.getRapidMinerTypeIndex(metaData.getColumnType(columnIndex));
				final Attribute attribute = AttributeFactory.createAttribute(columnName, attributeType);
				attribute.getAnnotations().setAnnotation("sql_type", metaData.getColumnTypeName(columnIndex));
				result.add(attribute);
			}
		}

		return result;
	}
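The renaming loop in getAttributes() can be exercised on plain strings. With registerGeneratedNames = false, the sketch below reproduces the original loop. Note that the original records only the database column name in duplicateNameMap, never the generated name. A later column literally named a_1 can therefore collide with a generated a_1. The registerGeneratedNames = true option is a hypothetical fix, not RapidMiner behavior.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class DuplicateNameSketch {
    // Re-implements the renaming loop from getAttributes() on plain strings.
    static List<String> uniqueNames(List<String> dbNames, boolean registerGeneratedNames) {
        // Maps a name to how often it has been chosen so far.
        Map<String, Integer> duplicateNameMap = new HashMap<>();
        List<String> result = new ArrayList<>();
        for (String dbColumnName : dbNames) {
            String columnName = dbColumnName;
            Integer duplicateCount = duplicateNameMap.get(dbColumnName);
            if (duplicateCount == null) {
                // First occurrence: keep the name as-is.
                duplicateNameMap.put(columnName, 1);
            } else {
                // Already used: try name_1, name_2, ... until unused.
                boolean isUnique = false;
                while (!isUnique) {
                    duplicateCount = duplicateCount + 1;
                    columnName = dbColumnName + "_" + (duplicateCount - 1);
                    isUnique = duplicateNameMap.get(columnName) == null;
                }
                duplicateNameMap.put(dbColumnName, duplicateCount);
                if (registerGeneratedNames) {
                    // Hypothetical fix: also reserve the generated name.
                    duplicateNameMap.put(columnName, 1);
                }
            }
            result.add(columnName);
        }
        return result;
    }
}
```

For the column labels a, a, a_1, the faithful variant yields a, a_1, a_1, while the fixed variant yields a, a_1, a_1_1.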

1.4.3 createExampleTable()

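The third method is createExampleTable(). For each row of the ResultSet, it converts every column to a double according to the attribute's value type:

- Date/time values become milliseconds via getTime().
- Numerical values are read with getDouble().
- Nominal values are mapped to indices through the attribute's mapping. CLOBs are first read line by line into a string.
- SQL NULL becomes Double.NaN.

Each converted row is then added to a MemoryExampleTable.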

	public static MemoryExampleTable createExampleTable(ResultSet resultSet, List<Attribute> attributes, int dataManagementType, Logger logger) throws SQLException, OperatorException {
		ResultSetMetaData metaData = resultSet.getMetaData();
		Attribute[] attributeArray = attributes.toArray(new Attribute[attributes.size()]);
		MemoryExampleTable table = new MemoryExampleTable(attributes);
		DataRowFactory factory = new DataRowFactory(dataManagementType, '.');
		while (resultSet.next()) {
			DataRow dataRow = factory.create(attributeArray.length);
			// double[] data = new double[attributeArray.length];
			for (int i = 1; i <= metaData.getColumnCount(); i++) {
				Attribute attribute = attributeArray[i - 1];
				int valueType = attribute.getValueType();
				double value;
				if (Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.DATE_TIME)) {
					Timestamp timestamp = resultSet.getTimestamp(i);
					if (resultSet.wasNull()) {
						value = Double.NaN;
					} else {
						value = timestamp.getTime();
					}
				} else if (Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.NUMERICAL)) {
					value = resultSet.getDouble(i);
					if (resultSet.wasNull()) {
						value = Double.NaN;
					}
				} else {
					if (Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.NOMINAL)) {
						String valueString;
						if (metaData.getColumnType(i) == Types.CLOB) {
							Clob clob = resultSet.getClob(i);
							if (clob != null) {
								BufferedReader in = null;
								try {
									in = new BufferedReader(clob.getCharacterStream());
									String line = null;
									try {
										StringBuffer buffer = new StringBuffer();
										while ((line = in.readLine()) != null) {
											buffer.append(line + "\n");
										}
										valueString = buffer.toString();
									} catch (IOException e) {
										throw new OperatorException("Database error occurred: " + e, e);
									}
								} finally {
									try {
										in.close();
									} catch (IOException e) {}
								}
							} else {
								valueString = null;
							}
						} else {
							valueString = resultSet.getString(i);
						}
						if (resultSet.wasNull() || valueString == null) {
							value = Double.NaN;
						} else {
							value = attribute.getMapping().mapString(valueString);
						}
					} else {
						if (logger != null) {
							logger.warning("Unknown column type: " + attribute);
						}
						value = Double.NaN;
					}
				}
				dataRow.set(attribute, value);
				// data[i-1] = value;
			}
			table.addDataRow(dataRow); // new DoubleArrayDataRow(data));
		}
		return table;
	}
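The per-value conversion in createExampleTable() boils down to this: every cell becomes a double, and nominal strings are replaced by their index in the attribute's mapping. The simplified sketch below uses hypothetical types. RowConversionSketch and its Mapping (which hands out indices 0, 1, 2, ... in order of first appearance) are not RapidMiner classes, and the sketch takes already-fetched Java values instead of calling ResultSet getters.

```java
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of the cell conversion in createExampleTable().
class RowConversionSketch {
    enum ValueType { DATE_TIME, NUMERICAL, NOMINAL }

    // Stand-in for a nominal mapping: each new string gets the next index.
    static final class Mapping {
        private final Map<String, Integer> index = new HashMap<>();
        private final List<String> values = new ArrayList<>();

        int mapString(String s) {
            Integer i = index.get(s);
            if (i == null) {
                i = values.size();
                index.put(s, i);
                values.add(s);
            }
            return i;
        }

        String mapIndex(int i) {
            return values.get(i);
        }
    }

    // Converts one cell; null (SQL NULL) becomes NaN, as in the original.
    static double toDouble(Object raw, ValueType type, Mapping mapping) {
        if (raw == null) {
            return Double.NaN;
        }
        switch (type) {
            case DATE_TIME:
                return ((Timestamp) raw).getTime(); // milliseconds since the epoch
            case NUMERICAL:
                return ((Number) raw).doubleValue();
            case NOMINAL:
                return mapping.mapString(raw.toString());
            default:
                return Double.NaN;
        }
    }
}
```

Storing nominal values as indices lets every row be a flat array of doubles. The mapping is then the only place where the strings live.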

1.4.4 SimpleExampleSet()

The last step, table.createExampleSet(), ultimately calls the SimpleExampleSet constructor, i.e. return new SimpleExampleSet(...).

	public SimpleExampleSet(ExampleTable exampleTable, List<Attribute> regularAttributes, Map<Attribute, String> specialAttributes) {
		this.exampleTable = exampleTable;
		List<Attribute> regularList = regularAttributes;
		if (regularList == null) {
			regularList = new LinkedList<Attribute>();
			for (int a = 0; a < exampleTable.getNumberOfAttributes(); a++) {
				Attribute attribute = exampleTable.getAttribute(a);
				if (attribute != null)
					regularList.add(attribute);	
			}
		}
		
		for (Attribute attribute : regularList) {
			if ((specialAttributes == null) || (specialAttributes.get(attribute) == null))
				getAttributes().add(new AttributeRole((Attribute) attribute.clone()));
		}
		
		if (specialAttributes != null) {
			Iterator<Map.Entry<Attribute, String>> s = specialAttributes.entrySet().iterator();
			while (s.hasNext()) {
				Map.Entry<Attribute, String> entry = s.next();
				getAttributes().setSpecialAttribute((Attribute) entry.getKey().clone(), entry.getValue());
			}
		}
	}
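The SimpleExampleSet constructor splits attributes into regular and special roles in three steps:

1. If no regular list is given, it takes all non-null attributes of the table.
2. It skips any attribute that also appears in the special map.
3. It adds the special attributes with their role names.

The sketch below models just that rule, with attributes reduced to strings and a hypothetical "regular" marker. The real constructor additionally clones each Attribute before wrapping it in an AttributeRole.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the role split in the SimpleExampleSet constructor.
class RoleAssignmentSketch {
    // Returns attribute name -> role ("regular" is a hypothetical marker).
    static Map<String, String> assignRoles(List<String> tableAttributes,
                                           List<String> regularAttributes,
                                           Map<String, String> specialAttributes) {
        // No explicit regular list: take every non-null attribute of the table.
        List<String> regularList = regularAttributes;
        if (regularList == null) {
            regularList = new ArrayList<>();
            for (String attribute : tableAttributes) {
                if (attribute != null) {
                    regularList.add(attribute);
                }
            }
        }
        Map<String, String> roles = new LinkedHashMap<>();
        // Regular attributes first, skipping any that also have a special role.
        for (String attribute : regularList) {
            if (specialAttributes == null || specialAttributes.get(attribute) == null) {
                roles.put(attribute, "regular");
            }
        }
        // Then the special attributes (e.g. "label", "id") with their role names.
        if (specialAttributes != null) {
            for (Map.Entry<String, String> entry : specialAttributes.entrySet()) {
                roles.put(entry.getKey(), entry.getValue());
            }
        }
        return roles;
    }
}
```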


Reposted from blog.csdn.net/liyuhui195134/article/details/81166834