Files
irrigation-model/irrgiation/db_connect.py
2025-12-23 08:38:08 +08:00

199 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import configparser
import os
from pathlib import Path
from typing import Optional, Union, List
import pandas as pd
import psycopg2
from dotenv import load_dotenv
from psycopg2 import sql
from psycopg2.extras import execute_batch
from sqlalchemy import create_engine, text
def get_config(config_path=None):
    """Load the PostgreSQL connection settings from an INI-style config file.

    Args:
        config_path: Optional path to the config file. When omitted, the
            ``config.nifi`` file located one directory above this module's
            directory is used.

    Returns:
        A dict with the keys ``host``, ``port``, ``database``, ``user``,
        ``password``, ``schema`` and ``tablename``, holding the string values
        read from the ``[postgresql]`` section. ``None`` is returned (and the
        reason printed) when the file or section is missing or any of those
        keys is absent.
    """
    # 'tablename' is part of the returned dict, so it has to be validated
    # together with the connection keys; otherwise its absence surfaces as a
    # bare KeyError message instead of the explicit "Missing required" one.
    required_keys = ('host', 'port', 'database', 'user', 'password',
                     'schema', 'tablename')
    config = configparser.ConfigParser()
    try:
        if config_path is None:
            config_path = Path(__file__).parent.parent / 'config.nifi'
        # ConfigParser.read() silently skips unreadable/missing files; the
        # section lookup below then raises KeyError, which is reported here.
        config.read(config_path)
        section = config['postgresql']
        for key in required_keys:
            if key not in section:
                raise ValueError(f"Missing required config key: {key}")
        return {key: section[key] for key in required_keys}
    except Exception as e:
        print(f"配置读取错误: {e}")
        # Callers must handle a None result.
        return None
def connect():
    """Open a psycopg2 connection using the settings from ``get_config()``.

    Returns:
        The open psycopg2 connection. The caller is responsible for closing
        it. (Previously the connection was created and then dropped, so the
        function returned ``None`` and leaked the connection.)

    Raises:
        RuntimeError: If the configuration could not be loaded.
        Exception: Any error raised by ``psycopg2.connect`` is printed and
            re-raised unchanged.
    """
    try:
        config_params = get_config()
        if config_params is None:
            # get_config() reports failures by returning None.
            raise RuntimeError("database configuration could not be loaded")
        return psycopg2.connect(
            host=config_params['host'],
            port=config_params['port'],
            dbname=config_params['database'],
            user=config_params['user'],
            password=config_params['password'],
        )
    except Exception as e:
        print(f"❌ 连接失败: {e}")
        raise
# Query rows by condition
def query_postgresql_to_dataframe(
        condition: Optional[dict] = None,
        condition_operator: str = "AND",
        columns: Union[str, List[str]] = "*",):
    """Run a SELECT against the configured table and return the rows as a DataFrame.

    Args:
        condition: Optional mapping of column name -> value. Each entry
            becomes a parameterized ``"column" = value`` equality test.
        condition_operator: How the equality tests are combined; must be
            ``"AND"`` or ``"OR"`` (case-insensitive).
        columns: Either a list of column names (each quoted as an
            identifier) or a string that is placed into the statement
            verbatim (default ``"*"``).

    Returns:
        A DataFrame holding the fetched rows, with column names taken from
        the cursor description; an empty DataFrame when the statement
        produced no result set.

    Raises:
        ValueError: If ``condition_operator`` is not AND/OR.
        RuntimeError: If the database configuration could not be loaded.

    Note:
        The table is referenced by the bare ``tablename`` from the config;
        the configured ``schema`` value is not applied here.
    """
    # The operator is spliced into the SQL text, so only the two known
    # keywords are accepted. Checked first so bad input never opens a connection.
    operator = str(condition_operator).strip().upper()
    if operator not in ("AND", "OR"):
        raise ValueError(
            f"condition_operator must be 'AND' or 'OR', got {condition_operator!r}"
        )

    config_params = get_config()
    if config_params is None:
        # get_config() reports failures by returning None.
        raise RuntimeError("database configuration could not be loaded")

    conn = psycopg2.connect(
        host=config_params['host'],
        port=config_params['port'],
        dbname=config_params['database'],
        user=config_params['user'],
        password=config_params['password'],
    )
    try:
        # Column selection.
        if isinstance(columns, list):
            columns_sql = sql.SQL(", ").join(sql.Identifier(col) for col in columns)
        else:
            # NOTE(review): a string is inserted as raw SQL (needed for "*").
            # Never pass untrusted text here; use the list form instead.
            columns_sql = sql.SQL(columns)

        # Base statement.
        query = sql.SQL("SELECT {} FROM {}").format(
            columns_sql,
            sql.Identifier(config_params['tablename'])
        )

        # WHERE clause: identifiers are quoted, values travel as named parameters.
        params = {}
        if condition:
            clauses = []
            for i, (key, value) in enumerate(condition.items()):
                param_name = f"param_{i}"
                clauses.append(sql.SQL("{} = {}").format(
                    sql.Identifier(key),
                    sql.Placeholder(param_name)
                ))
                params[param_name] = value
            query = query + sql.SQL(" WHERE ") + sql.SQL(f" {operator} ").join(clauses)

        # Debug output: show the rendered statement rather than the Composed repr.
        print(query.as_string(conn))

        # Execute and collect the result.
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            if cursor.description:
                # Separate name so the ``columns`` argument is not shadowed.
                column_names = [desc[0] for desc in cursor.description]
                return pd.DataFrame(cursor.fetchall(), columns=column_names)
        return pd.DataFrame()
    finally:
        conn.close()
# Insert rows into the table
def dataframe_to_postgresql_batch(df, batch_size=1000):
    """Insert every row of a DataFrame into the configured table using execute_batch.

    The target table and the connection settings come from ``get_config()``;
    the DataFrame's column names are used as the target column names.

    Args:
        df: DataFrame to insert. An empty DataFrame is a no-op.
        batch_size: Number of rows sent per batch (``page_size`` of
            ``execute_batch``).

    Returns:
        None. Failures are not raised: the transaction is rolled back (when
        a connection exists) and the error is printed.
    """
    if df.empty:
        print("DataFrame为空无需插入")
        return

    columns = df.columns.tolist()
    # astype(object) boxes numpy scalars (e.g. numpy.int64) into plain Python
    # values, which psycopg2 can adapt; to_numpy() on an all-numeric frame
    # would hand numpy scalars to the driver.
    data = list(df.astype(object).itertuples(index=False, name=None))

    # Bound before the try block so the except/finally clauses can tell
    # whether a connection was ever opened.
    conn = None
    try:
        config_params = get_config()
        if config_params is None:
            # get_config() reports failures by returning None.
            raise RuntimeError("database configuration could not be loaded")
        conn = psycopg2.connect(
            host=config_params['host'],
            port=config_params['port'],
            dbname=config_params['database'],
            user=config_params['user'],
            password=config_params['password'],
        )
        # Build the INSERT statement with quoted identifiers and one
        # positional placeholder per column.
        insert_query = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
            sql.Identifier(config_params['tablename']),
            sql.SQL(', ').join(map(sql.Identifier, columns)),
            sql.SQL(', ').join([sql.Placeholder()] * len(columns))
        )
        # Execute in batches; the cursor is closed when the block exits.
        with conn.cursor() as cursor:
            execute_batch(cursor, insert_query, data, page_size=batch_size)
        conn.commit()
    except Exception as e:
        # Report first so the root cause is visible even if rollback fails.
        print(f"批量插入时出错: {e}")
        if conn is not None:
            conn.rollback()
    finally:
        if conn is not None:
            conn.close()
# Update a single field of the row matching the given condition
def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def update_irrigation_data(date_value, id_value, field_to_update, new_value,
                           db_url="postgresql://postgres:postgres@localhost:5432/datastore",
                           table_name="irrigation_data",
                           date_field="Date",
                           id_field="dkbm"):
    """Update one field of the row(s) matching a date and a plot id.

    Args:
        date_value: Value compared against the ``date_field`` column.
        id_value: Value compared against the ``id_field`` column.
        field_to_update: Name of the column to set.
        new_value: Value written to ``field_to_update``.
        db_url: SQLAlchemy connection URL.
            NOTE(review): the default embeds credentials; prefer passing
            the URL explicitly.
        table_name: Target table, optionally schema-qualified with dots.
        date_field: Name of the date column.
        id_field: Name of the id/plot column.

    Returns:
        True when the UPDATE statement executed without error (even if it
        matched no rows), False when any error occurred; the error is printed.
    """
    # Bound before the try block so the finally clause can tell whether an
    # engine was ever created.
    engine = None
    try:
        # Values go through bind parameters; identifiers cannot, so each one
        # is quoted with embedded double quotes escaped.
        quoted_table = ".".join(
            _quote_identifier(part) for part in str(table_name).split(".")
        )
        update_sql = text(
            f"UPDATE {quoted_table} "
            f"SET {_quote_identifier(field_to_update)} = :new_value "
            f"WHERE {_quote_identifier(date_field)} = :date_value "
            f"AND {_quote_identifier(id_field)} = :id_value"
        )
        engine = create_engine(db_url)
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as conn:
            conn.execute(update_sql, {
                'new_value': new_value,
                'date_value': date_value,
                'id_value': id_value
            })
        return True
    except Exception as e:
        print(f"更新失败: {str(e)}")
        return False
    finally:
        # Release pooled connections.
        if engine is not None:
            engine.dispose()
if __name__ == '__main__':
    # Manual run: parse the project-level config.nifi file. The parsed
    # values are not used any further in the visible code.
    project_nifi = Path(__file__).parent.parent / 'config.nifi'
    config = configparser.ConfigParser()
    config.read(project_nifi)