199 lines
6.7 KiB
Python
199 lines
6.7 KiB
Python
import configparser
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Optional, Union, List
|
||
|
||
import pandas as pd
|
||
import psycopg2
|
||
from dotenv import load_dotenv
|
||
from psycopg2 import sql
|
||
from psycopg2.extras import execute_batch
|
||
from sqlalchemy import create_engine, text
|
||
|
||
|
||
def get_config(config_path=None):
    """Load the PostgreSQL settings from an INI-style configuration file.

    Args:
        config_path: Optional path of the configuration file. When omitted,
            ``config.nifi`` located two directory levels above this module
            file (``Path(__file__).parent.parent``) is used.

    Returns:
        A dict with the keys ``host``, ``port``, ``database``, ``user``,
        ``password``, ``schema`` and ``tablename``, holding the string values
        of the ``[postgresql]`` section, or ``None`` when the file cannot be
        read or is incomplete. The failure reason is printed in that case.
    """
    # 'tablename' is returned below, so it has to be validated like the others.
    required_keys = ['host', 'port', 'database', 'user', 'password', 'schema', 'tablename']
    config = configparser.ConfigParser()
    try:
        if config_path is None:
            config_path = Path(__file__).parent.parent / 'config.nifi'

        # ConfigParser.read() silently ignores files it cannot open, so a
        # missing file only shows up as a missing section afterwards.
        config.read(config_path)
        if not config.has_section('postgresql'):
            raise ValueError("Missing required config section: postgresql")

        section = config['postgresql']
        # Make sure every required key is present.
        for key in required_keys:
            if key not in section:
                raise ValueError(f"Missing required config key: {key}")

        return {key: section[key] for key in required_keys}
    except Exception as e:
        print(f"配置读取错误: {e}")
        # Callers receive None and must handle the missing configuration.
        return None
|
||
def connect():
    """Open a PostgreSQL connection using the settings from ``get_config()``.

    Returns:
        The open psycopg2 connection. The caller is responsible for closing it.

    Raises:
        Exception: Any failure while loading the configuration or connecting
            is printed and then re-raised unchanged. When ``get_config()``
            returns ``None`` this is the ``TypeError`` raised by subscripting
            ``None``.
    """
    try:
        config_params = get_config()
        # Return the connection so it can actually be used (and closed) by
        # the caller instead of being dropped on the floor.
        return psycopg2.connect(
            host=config_params['host'],
            port=config_params['port'],
            dbname=config_params['database'],
            user=config_params['user'],
            password=config_params['password'],
        )
    except Exception as e:
        print(f"❌ 连接失败: {e}")
        raise
|
||
|
||
# 根据条件查询
|
||
def query_postgresql_to_dataframe(
        condition: Optional[dict] = None,
        condition_operator: str = "AND",
        columns: Union[str, List[str]] = "*",):
    """Select rows from the configured table and return them as a DataFrame.

    Args:
        condition: Optional mapping of column name to value. Each entry
            becomes a ``"column" = value`` equality test with the value
            passed as a bound parameter.
        condition_operator: ``"AND"`` or ``"OR"`` (case-insensitive); joins
            the equality tests.
        columns: A list of column names (each quoted as an identifier) or a
            raw SQL column expression such as ``"*"``. A string is inserted
            into the statement verbatim, so it must come from trusted code.

    Returns:
        A DataFrame with one column per result column, or an empty DataFrame
        when the statement produced no result set.

    Raises:
        ValueError: If ``condition_operator`` is neither AND nor OR.
    """
    # The operator is spliced into the SQL text, so only accept the two
    # known keywords. Checked first so no connection is opened for bad input.
    operator = condition_operator.strip().upper()
    if operator not in ("AND", "OR"):
        raise ValueError(f"Unsupported condition_operator: {condition_operator!r}")

    # Connect to the database.
    config_params = get_config()
    conn = psycopg2.connect(
        host=config_params['host'],
        port=config_params['port'],
        dbname=config_params['database'],
        user=config_params['user'],
        password=config_params['password'],
    )
    try:
        # Build the column list.
        if isinstance(columns, list):
            columns_sql = sql.SQL(", ").join(sql.Identifier(col) for col in columns)
        else:
            columns_sql = sql.SQL(columns)

        # Build the base query.
        # NOTE(review): the configured 'schema' value is not used here; only
        # the bare table name is referenced — confirm this is intended.
        query = sql.SQL("SELECT {} FROM {}").format(
            columns_sql,
            sql.Identifier(config_params['tablename']),
        )

        # Build the parameterized WHERE clause.
        params = {}
        if condition:
            conditions = []
            for i, (key, value) in enumerate(condition.items()):
                param_name = f"param_{i}"
                conditions.append(sql.SQL("{} = {}").format(
                    sql.Identifier(key),
                    sql.Placeholder(param_name),
                ))
                params[param_name] = value

            where_clause = sql.SQL(" WHERE {}").format(
                sql.SQL(f" {operator} ").join(conditions)
            )
            query = query + where_clause
        print(query)

        # Run the query.
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            if cursor.description:
                column_names = [desc[0] for desc in cursor.description]
                data = cursor.fetchall()
                return pd.DataFrame(data, columns=column_names)
            return pd.DataFrame()
    finally:
        conn.close()
|
||
|
||
# 插入库表
|
||
def dataframe_to_postgresql_batch(df, batch_size=1000):
    """Insert all rows of a DataFrame into the configured table in batches.

    The connection settings and the target table name are taken from
    ``get_config()``. The DataFrame's column names are used as the target
    column names. Errors are printed and swallowed (best effort); the
    transaction is rolled back when a connection had been opened.

    Args:
        df: DataFrame whose rows are inserted.
        batch_size: Number of rows sent per batch (``page_size`` of
            ``execute_batch``).

    Returns:
        None.
    """
    if df.empty:
        print("DataFrame为空,无需插入")
        return

    # Column names double as the target column names.
    columns = df.columns.tolist()
    # Cast to object first so NumPy scalars (e.g. numpy.int64) become plain
    # Python objects, which psycopg2 knows how to adapt.
    data = [tuple(row) for row in df.astype(object).to_numpy()]

    # Bound before the try block so the except/finally clauses can tell
    # whether a connection was ever opened.
    conn = None
    try:
        config_params = get_config()
        conn = psycopg2.connect(
            host=config_params['host'],
            port=config_params['port'],
            dbname=config_params['database'],
            user=config_params['user'],
            password=config_params['password'],
        )

        # Build the INSERT statement.
        insert_query = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
            sql.Identifier(config_params['tablename']),
            sql.SQL(', ').join(map(sql.Identifier, columns)),
            sql.SQL(', ').join([sql.Placeholder()] * len(columns)),
        )

        # Execute in batches; the cursor is closed when the block exits.
        with conn.cursor() as cursor:
            execute_batch(cursor, insert_query, data, page_size=batch_size)

        conn.commit()
    except Exception as e:
        if conn is not None:
            conn.rollback()
        print(f"批量插入时出错: {e}")
    finally:
        if conn is not None:
            conn.close()
|
||
# 根据条件更新某个字段
|
||
def update_irrigation_data(date_value, id_value, field_to_update, new_value):
    """Update one column of the ``irrigation_data`` rows matching date and id.

    Rows are matched on the ``"Date"`` and ``"dkbm"`` columns.

    Args:
        date_value: Value compared against the ``"Date"`` column.
        id_value: Value compared against the ``"dkbm"`` column.
        field_to_update: Name of the column to set. It is placed inside a
            double-quoted identifier, so names that are empty or contain
            ``"``, ``:`` or a NUL character are rejected.
        new_value: Value assigned to the column (bound parameter).

    Returns:
        True when the statement executed without error, False otherwise
        (the error is printed).
    """
    # Bound before the try block so the finally clause is safe even when
    # engine creation itself fails.
    engine = None
    try:
        # A column name cannot be a bound parameter, so it is validated
        # instead: '"' would terminate the quoted identifier and ':' would be
        # parsed by text() as a bind-parameter marker.
        column_name = str(field_to_update)
        if not column_name or any(ch in column_name for ch in ('"', ':', '\x00')):
            raise ValueError(f"Invalid column name: {field_to_update!r}")

        # Database settings.
        # NOTE(review): these are hard-coded here rather than read through
        # get_config() like the other helpers — confirm this is intended.
        db_config = {
            "db_url": "postgresql://postgres:postgres@localhost:5432/datastore",
            "table_name": "irrigation_data",
            "date_field": "Date",
            "id_field": "dkbm"
        }
        engine = create_engine(db_config["db_url"])
        table_name = db_config["table_name"]
        date_field = db_config["date_field"]
        id_field = db_config["id_field"]
        with engine.begin() as conn:  # commits on success, rolls back on error
            # Values are passed as bound parameters.
            update_sql = text(f"""
                UPDATE {table_name}
                SET "{column_name}" = :new_value
                WHERE "{date_field}" = :date_value AND "{id_field}" = :id_value""")
            conn.execute(update_sql, {
                'new_value': new_value,
                'date_value': date_value,
                'id_value': id_value
            })

        return True
    except Exception as e:
        print(f"更新失败: {str(e)}")
        return False
    finally:
        if engine is not None:
            engine.dispose()  # release pooled connections
|
||
|
||
|
||
if __name__ == '__main__':
|
||
config = configparser.ConfigParser()
|
||
project_nifi = Path(__file__).parent.parent / 'config.nifi'
|
||
config.read(project_nifi) |