diff --git a/AGENTS.md b/AGENTS.md index 9ba5d35..f789b85 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -24,9 +24,9 @@ When writing code or configurations for this project, you **MUST** strictly adhe ### A. Python Coding & Naming Standards -- **Primary Type/First-Class Identity Prefix**: Place the variable's type, role, or primary characteristics first in its name. - - _Correct_: `name_service`, `port_service`, `svc_ingress`, `cfg_postgres`, `db_mysql`. - - _Incorrect_: `service_name`, `service_port`, `ingress_service`, `postgres_config`, `mysql_db`. +- **Primary Type/First-Class Identity Prefix**: Place the variable's type, role, or primary characteristics first in its name. If multiple variables in a segment of code belong to the same category, type, or semantic group, place the common semantic prefix first. + - _Correct_: `name_service`, `port_service`, `svc_ingress`, `cfg_postgres`, `db_mysql`, `msg_err`, `msg_info`. + - _Incorrect_: `service_name`, `service_port`, `ingress_service`, `postgres_config`, `mysql_db`, `err_msg`, `info_msg`. - **Logger Naming**: Use lowercase with underscores for logger names, e.g., `db_sync`, `api_router`. - **Import Conventions**: Use relative imports if possible, especially inside a package. diff --git a/doc/skills/aloha_python/SKILL.md b/doc/skills/aloha_python/SKILL.md index bf5f95c..be31088 100644 --- a/doc/skills/aloha_python/SKILL.md +++ b/doc/skills/aloha_python/SKILL.md @@ -13,9 +13,9 @@ This skill provides coding standards, modular application structures, and usage When developing Python code in this codebase, adhere to the following naming conventions: -- **Primary Type/First-Class Identity Prefix**: Place the variable's type, role, or primary characteristics first in its name. - - _Correct_: `name_service`, `port_service`, `svc_ingress`, `cfg_postgres`. - - _Incorrect_: `service_name`, `service_port`, `ingress_service`, `postgres_config`. 
+- **Primary Type/First-Class Identity Prefix**: Place the variable's type, role, or primary characteristics first in its name. If multiple variables in a segment of code belong to the same category, type, or semantic group, place the common semantic prefix first. + - _Correct_: `name_service`, `port_service`, `svc_ingress`, `cfg_postgres`, `db_mysql`, `msg_err`, `msg_info`. + - _Incorrect_: `service_name`, `service_port`, `ingress_service`, `postgres_config`, `mysql_db`, `err_msg`, `info_msg`. - **Logger naming**: Use lowercase with underscores for logger names, e.g., `db_sync`. --- diff --git a/pkg/aloha/db/__init__.py b/pkg/aloha/db/__init__.py index e69de29..e41df7c 100644 --- a/pkg/aloha/db/__init__.py +++ b/pkg/aloha/db/__init__.py @@ -0,0 +1,176 @@ +""" +aloha.db package - Database and middleware connection helpers. + +Sync modules (blocking): + from aloha.db import PostgresOperator, MySqlOperator, RedisOperator, ... + from aloha.db.postgres import PostgresOperator + from aloha.db.mysql import MySqlOperator + from aloha.db.redis import RedisOperator + from aloha.db.mongo import MongoOperator + from aloha.db.elasticsearch import ElasticSearchOperator + from aloha.db.kafka import KafkaOperator + from aloha.db.sqlite import SqliteOperator + from aloha.db.duckdb import DuckOperator + from aloha.db.oracle import OracledbOperator + +Async modules (non-blocking): + from aloha.db import PostgresOperator as PostgresOperatorAio, ... 
+ from aloha.db.postgres_aio import PostgresOperator + from aloha.db.mysql_aio import MySqlOperator + from aloha.db.redis_aio import RedisOperator + from aloha.db.mongo_aio import MongoOperator + from aloha.db.elasticsearch_aio import ElasticSearchOperator + from aloha.db.kafka_aio import KafkaOperator + from aloha.db.sqlite_aio import SqliteOperator + from aloha.db.duckdb_aio import DuckOperator + from aloha.db.oracle_aio import OracledbOperator + +Base utilities: + from aloha.db.base import PasswordVault + from aloha.db.base_aio import PasswordVault # async version + +Usage example (sync): + from aloha.db.postgres import PostgresOperator + + op = PostgresOperator(db_config) + result = op.execute_query("SELECT * FROM users") + for row in result: + print(row) + +Usage example (async): + from aloha.db.postgres_aio import PostgresOperator + + async def main(): + op = PostgresOperator(db_config) + result = await op.execute_query("SELECT * FROM users") + async for row in op.execute_query_scalars("SELECT * FROM users"): + print(row) + await op.close() + + import asyncio + asyncio.run(main()) +""" + +# Sync modules +from .base import PasswordVault + +try: + from .postgres import PostgresOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .mysql import MySqlOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .redis import RedisOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .mongo import MongoOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .elasticsearch import ElasticSearchOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .kafka import KafkaOperator, ConsumedMessage +except (ImportError, ModuleNotFoundError): + pass + +try: + from .sqlite import SqliteOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .duckdb import DuckOperator +except (ImportError, ModuleNotFoundError): + pass + +try: + from .oracle import OracledbOperator 
+except (ImportError, ModuleNotFoundError): + pass + + +# Async modules (importable as aliases for easy switching) +from .base_aio import PasswordVault as PasswordVaultAio + +try: + from .postgres_aio import PostgresOperator as PostgresOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .mysql_aio import MySqlOperator as MySqlOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .redis_aio import RedisOperator as RedisOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .mongo_aio import MongoOperator as MongoOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .elasticsearch_aio import ElasticSearchOperator as ElasticSearchOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .kafka_aio import KafkaOperator as KafkaOperatorAio, ConsumedMessage as ConsumedMessageAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .sqlite_aio import SqliteOperator as SqliteOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .duckdb_aio import DuckOperator as DuckOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +try: + from .oracle_aio import OracledbOperator as OracledbOperatorAio +except (ImportError, ModuleNotFoundError): + pass + +__all__ = ( + # Sync operators + "PostgresOperator", + "MySqlOperator", + "RedisOperator", + "MongoOperator", + "ElasticSearchOperator", + "KafkaOperator", + "ConsumedMessage", + "SqliteOperator", + "DuckOperator", + "OracledbOperator", + "PasswordVault", + # Async operators (aliased) + "PostgresOperatorAio", + "MySqlOperatorAio", + "RedisOperatorAio", + "MongoOperatorAio", + "ElasticSearchOperatorAio", + "KafkaOperatorAio", + "ConsumedMessageAio", + "SqliteOperatorAio", + "DuckOperatorAio", + "OracledbOperatorAio", + "PasswordVaultAio", +) \ No newline at end of file diff --git a/pkg/aloha/db/base_aio.py b/pkg/aloha/db/base_aio.py new file mode 100644 index 
0000000..002a9f6 --- /dev/null +++ b/pkg/aloha/db/base_aio.py @@ -0,0 +1,85 @@ +""" +Async password vault manager for async database operations. +""" + +from ..encrypt import vault +from ..logger import LOG +from ..settings import SETTINGS + + +class PasswordVault: + """ + Async password vault manager that provides access to password vault implementations. + + Caches vault instances for performance. + """ + + _dict_cache_vault = {} + + @staticmethod + async def get_vault(vault_type: str | None = None, vault_config: dict | None = None, **kwargs) -> vault.BaseVault: + """ + Get a password vault instance (async version). + + Supports multiple vault types: + - 'plain' or 'aes': AES-based vault (default fallback) + - 'cyberark': CyberArk vault + - Other/None: Dummy vault (plain text) + + :param vault_type: Type of vault to use (overrides config) + :param vault_config: Vault configuration dictionary + :param args: Additional arguments + :param kwargs: Additional keyword arguments + :return: Vault instance implementing BaseVault interface + :raises RuntimeError: If CyberArk vault is requested but config is missing + """ + encryption_method = vault_type or SETTINGS.config.get("PASSWORD_ENCRYPTION") + LOG.debug("Using password vault (async): %s", encryption_method) + + cache_key = "%s:%s" % (encryption_method, str(vault_config)) + if cache_key not in PasswordVault._dict_cache_vault: + if encryption_method in ("plain", "aes") or encryption_method is True: + v = vault.AesVault(**(vault_config or {})) + elif encryption_method == "cyberark": + config_cyberark = vault_config or SETTINGS.config.get("CYBERARK_CONFIG") + if config_cyberark is None: + raise RuntimeError("Missing [CYBERARK_CONFIG] in config!") + v = vault.CyberArkVault(**config_cyberark) + else: + msg = "Using plain password vault as unknown value of PASSWORD_ENCRYPTION=%s in config." 
% encryption_method + LOG.info(msg) + v = vault.DummyVault(**(vault_config or {})) + PasswordVault._dict_cache_vault[cache_key] = v + + return PasswordVault._dict_cache_vault[cache_key] + + @staticmethod + def get_vault_sync(vault_type: str | None = None, vault_config: dict | None = None, **kwargs) -> vault.BaseVault: + """ + Get a password vault instance (sync version for backward compatibility). + + :param vault_type: Type of vault to use (overrides config) + :param vault_config: Vault configuration dictionary + :param args: Additional arguments + :param kwargs: Additional keyword arguments + :return: Vault instance implementing BaseVault interface + """ + encryption_method = vault_type or SETTINGS.config.get("PASSWORD_ENCRYPTION") + LOG.debug("Using password vault (sync): %s", encryption_method) + + cache_key = "%s:%s" % (encryption_method, str(vault_config)) + if cache_key not in PasswordVault._dict_cache_vault: + if encryption_method in ("plain", "aes") or encryption_method is True: + v = vault.AesVault(**(vault_config or {})) + elif encryption_method == "cyberark": + config_cyberark = vault_config or SETTINGS.config.get("CYBERARK_CONFIG") + if config_cyberark is None: + raise RuntimeError("Missing [CYBERARK_CONFIG] in config!") + v = vault.CyberArkVault(**config_cyberark) + else: + msg = "Using plain password vault as unknown value of PASSWORD_ENCRYPTION=%s in config." % encryption_method + LOG.info(msg) + v = vault.DummyVault(**(vault_config or {})) + PasswordVault._dict_cache_vault[cache_key] = v + + return PasswordVault._dict_cache_vault[cache_key] \ No newline at end of file diff --git a/pkg/aloha/db/duckdb_aio.py b/pkg/aloha/db/duckdb_aio.py new file mode 100644 index 0000000..c3cf2b5 --- /dev/null +++ b/pkg/aloha/db/duckdb_aio.py @@ -0,0 +1,126 @@ +""" +Async DuckDB connection helpers. 
+""" + +from pathlib import Path + +import duckdb +import duckdb_engine +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy import text + +from aloha.logger import LOG + +__all__ = ("DuckOperator",) + +LOG.debug("duckdb_aio version = %s, duckdb_engine = %s (async)", (duckdb.__version__, duckdb_engine.__version__)) + + +class DuckOperator: + """Create and use an async DuckDB connection through SQLAlchemy.""" + + def __init__(self, db_config, **kwargs): + """Build an async DuckDB engine, creating the database file if necessary.""" + """db_config example: + { + "path": "/path/to/db.duckdb", # file path of duckdb, use ":memory:" for in-memory mode + "schema": "sales", # optional, 'main' by default + "read_only": True, # optional, False by default, (will set to False if in in-memory mode) + "config": {"memory_limit": "500mb"}, # optional, duckdb connection configs + } + """ + self._config = { + "path": db_config.get("path", ":memory:"), + "schema": db_config.get("schema", "main"), + "read_only": bool(db_config.get("read_only", False)), + "config": db_config.get("config", {}), + "auto_commit": db_config.get("auto_commit", True), + } + + if not self._config["path"] or self._config["path"] == ":memory:": + self._config["path"] = ":memory:" + + if self._config["read_only"]: + LOG.warning("In-memory database cannot be read-only. 
Setting read_only=False.") + self._config["read_only"] = False + + else: + self._prepare_database() + + try: + str_connection = f"duckdb+aioduckdb:///{self._config['path']}" + self.engine: AsyncEngine = create_async_engine( + str_connection, + connect_args={"read_only": self._config["read_only"], "config": self._config["config"]}, + **kwargs, + ) + + LOG.debug("DuckDB (async) connected: {path} [schema={schema}, read_only={read_only}]".format(**self._config)) + except Exception as e: + LOG.exception(e) + raise RuntimeError("Failed to connect to DuckDB (async)") + + def _prepare_database(self): + """Prepare the database file and its parent directory.""" + path = self._config["path"] + path_obj = Path(path) + + parent_dir = path_obj.parent + if not parent_dir.exists(): + if self._config["read_only"]: + raise RuntimeError(f"Directory '{parent_dir}' does not exist and read_only=True") + try: + parent_dir.mkdir(parents=True, exist_ok=True) + LOG.debug(f"Created directory: {parent_dir}") + except Exception as e: + raise RuntimeError(f"Failed to create directory '{parent_dir}': {e}") + + if not path_obj.exists(): + if self._config["read_only"]: + raise RuntimeError(f"DuckDB file '{path}' does not exist and read_only=True") + try: + LOG.debug(f"Database file not found, creating: {path}") + duckdb.connect(path).close() + except Exception as e: + raise RuntimeError(f"Failed to create database file '{path}': {e}") + + @property + def connection(self): + return self.engine + + @property + def conn(self): + """Alias for connection property.""" + return self.engine + + async def execute_query(self, sql, *args, **kwargs): + """Execute a SQL statement asynchronously and return the cursor result.""" + async with self.engine.connect() as conn: + cur = await conn.execute(text(sql), *args, **kwargs) + if self._config.get("auto_commit", True): + await conn.commit() + return cur + + async def execute_query_scalars(self, sql, *args, **kwargs): + """Execute a SQL statement and return all 
scalar results.""" + async with self.engine.connect() as conn: + cur = await conn.stream_scalars(text(sql), *args, **kwargs) + async for row in cur: + yield row + + @property + def connection_str(self) -> str: + """Return a human-readable connection string.""" + return f"duckdb:///{self._config['path']} [schema={self._config['schema']}, read_only={self._config['read_only']}] (async)" + + async def close(self): + """Close the async engine and all connections.""" + await self.engine.dispose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() \ No newline at end of file diff --git a/pkg/aloha/db/elasticsearch_aio.py b/pkg/aloha/db/elasticsearch_aio.py new file mode 100644 index 0000000..76e7e3b --- /dev/null +++ b/pkg/aloha/db/elasticsearch_aio.py @@ -0,0 +1,123 @@ +""" +Async Elasticsearch connection helpers. +""" + +import json +import re + +from elasticsearch import AsyncElasticsearch + +from ..logger import LOG +from .base_aio import PasswordVault + +__all__ = ("ElasticSearchOperator",) + + +def _mask_hosts(hosts): + if isinstance(hosts, list): + return [_mask_hosts(h) for h in hosts] + if isinstance(hosts, dict): + return {k: ("***" if k in ("password", "http_auth") else _mask_hosts(v)) for k, v in hosts.items()} + if isinstance(hosts, str): + return re.sub(r"([^:/]+://)?([^:/]+):([^@]+)@", r"\1\2:***@", hosts) + return hosts + + +class ElasticSearchOperator: + """Create and use an async Elasticsearch client with optional index helpers.""" + + def __init__(self, config, index_config=None): + """Build the async client and optionally load the index configuration.""" + self.es_config = config + + password_vault = PasswordVault.get_vault_sync(config.get("vault_type"), config.get("vault_config")) + username = config.get("username") + password = password_vault.get_password(config.get("password")) + + hosts = 
config.get("host", "localhost") + masked_hosts = _mask_hosts(hosts) + LOG.debug("ElasticSearch (async) connection info: " + str(masked_hosts)) + + self._config = { + "http_auth": (username, password) if username is not None and password is not None else None, + "hosts": hosts, + "timeout": config.get("timeout", 0.1), + "max_retries": config.get("max_retries", 3), + "retry_on_timeout": config.get("retry_on_timeout", True), + } + + self.index_config = index_config + self.index_name = self.es_config.get("index_name") + self.index_type = self.es_config.get("index_type") + + self.es: AsyncElasticsearch = AsyncElasticsearch(**self._config) + + if index_config is not None: + self.index_config = self._load_config(index_config) + + @staticmethod + def _load_config(config): + """Load an index configuration from a dict or JSON file.""" + if isinstance(config, dict): + return config + + elif isinstance(config, str) and ".json" in config: + with open(config, "r", encoding="utf-8") as f: + config = json.load(f) + return config + + else: + raise ValueError("Invalid ES config data type") + + async def put_mapping(self, index_name=None, index_type=None, index_config: dict | None = None): + """Apply a mapping definition to the current index asynchronously.""" + return await self.es.indices.put_mapping( + index=index_name or self.index_name, + doc_type=index_type or self.index_type, + body=index_config["mappings"][index_type or self.index_type], + ) + + async def build_index(self, index_name=None, index_config=None, raise_if_exist=False): + """Create the index if it does not already exist asynchronously.""" + if not await self.es.indices.exists(index=index_name or self.index_name): + res = await self.es.indices.create(index=index_name or self.index_name, body=index_config or self.index_config) + return res + else: + msg = "Index [%s] already exits" % self.index_name + if raise_if_exist: + raise RuntimeError(msg) + else: + LOG.info(msg) + return False + + async def search(self, query, 
index_name=None, index_type=None): + """Execute a search query asynchronously.""" + return await self.es.search(index=index_name or self.index_name, doc_type=index_type or self.index_type, body=query) + + async def msearch(self, body): + """Execute a multi-search request asynchronously.""" + return await self.es.msearch(body=body) + + async def insert(self, doc, index_name=None, index_type=None, id=None): + """Insert or replace a document asynchronously.""" + return await self.es.index(index=index_name or self.index_name, doc_type=index_type or self.index_type, id=id, body=doc) + + async def delete(self, index_name=None, index_type=None, id=None): + """Delete a document by ID asynchronously.""" + return await self.es.delete(index=index_name or self.index_name, doc_type=index_type or self.index_type, id=id) + + async def exists(self, index_name=None): + """Check if an index exists asynchronously.""" + return await self.es.indices.exists(index=index_name or self.index_name) + + async def close(self): + """Close the Elasticsearch client.""" + await self.es.close() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() \ No newline at end of file diff --git a/pkg/aloha/db/kafka.py b/pkg/aloha/db/kafka.py index 0817d4f..b8f1d12 100644 --- a/pkg/aloha/db/kafka.py +++ b/pkg/aloha/db/kafka.py @@ -2,17 +2,62 @@ import json import typing +from dataclasses import dataclass import confluent_kafka as kafka import confluent_kafka.admin as kafka_admin from ..logger import LOG -__all__ = ("KafkaOperator",) +__all__ = ("KafkaOperator", "ConsumedMessage") LOG.debug("Version of confluent_kafka client = %s" % kafka.__version__) +@dataclass +class ConsumedMessage: + """Represents a message consumed from Kafka.""" + + topic: str + partition: int + offset: int + key: str | bytes | None + value: str | bytes | None + headers: list[tuple[str, str | 
bytes]] | None = None + + +def _unpack_message(data: typing.Any) -> typing.Tuple[typing.Any, typing.Any, typing.Any]: + """Unpack data into (value, key, headers).""" + if isinstance(data, dict): + value = data.get("value") + key = data.get("key") + headers = data.get("headers") + elif hasattr(data, "value"): + value = data.value + key = getattr(data, "key", None) + headers = getattr(data, "headers", None) + else: + value = data + key = None + headers = None + return value, key, headers + + +def _prepare_headers(headers: typing.Any) -> typing.List[typing.Tuple[str, bytes]] | None: + if not headers: + return None + if isinstance(headers, dict): + headers = list(headers.items()) + processed = [] + for k, v in headers: + if isinstance(v, str): + v = v.encode("utf-8") + elif v is None: + v = b"" + processed.append((k, v)) + return processed + + class KafkaOperator: """Create Kafka admin, producer, and consumer clients.""" @@ -32,6 +77,8 @@ def __init__(self, kafka_config): "bootstrap.servers": ",".join(["{host}:{port}".format(**i) for i in kafka_config.pop("host")]), } LOG.debug("Kafka connection info: " + str(self._config)) + self._producer = None + self._consumer = None def admin_client(self, *args, **kwargs): """Return a configured Kafka AdminClient.""" @@ -57,15 +104,24 @@ def create_topic(self, topic: str, num_partitions=3, replication_factor=1, *args LOG.error("Failed to create topic {}: {}".format(topic, e)) return False finally: - a.close() + if hasattr(a, "close"): + a.close() return True - def producer_deliver(self, topic: str, generator: typing.Iterator[str], func_callback: callable = None, *args, **kwargs): + def producer(self) -> kafka.Producer: + """Return a configured Kafka Producer.""" + if self._producer is None: + config_producer = {**self._config} + self._producer = kafka.Producer(config_producer) + return self._producer + + def producer_deliver( + self, topic: str, generator: typing.Iterator[typing.Any], func_callback: callable = None, *args, 
**kwargs + ): """Stream messages from an iterator into a Kafka topic.""" # func_callback should be a function that takes two arguments: err and msg - config_producer = {**self._config} - p = kafka.Producer(config_producer) + p = self.producer() def delivery_report(err, msg): """Called once for each message produced to indicate delivery result. Triggered by poll() or flush().""" @@ -81,38 +137,107 @@ def delivery_report(err, msg): # Trigger any available delivery report callbacks from previous produce() calls p.poll(0) + value, key, headers = _unpack_message(data) + if isinstance(value, str): + value = value.encode("utf-8") + prepared_headers = _prepare_headers(headers) + # Asynchronously produce a message, the delivery report callback # will be triggered from poll() above, or flush() below, when the message has # been successfully delivered or failed permanently. - p.produce(topic, data.encode("utf-8"), callback=func_callback) + p.produce(topic, value=value, key=key, headers=prepared_headers, callback=func_callback) # Wait for any outstanding messages to be delivered and delivery report callbacks to be triggered. 
p.flush() def consumer_generator( self, topics_subscribe: list, group_id: str | None = None, poll_timeout: float = 1.0, *args, **kwargs - ) -> typing.Iterator[str]: + ) -> typing.Iterator[ConsumedMessage]: """Yield decoded messages from the subscribed Kafka topics.""" - config_consumer = {"auto.offset.reset": "earliest", **self._config} + config_consumer = {"auto.offset.reset": "earliest", "enable.auto.commit": False, **self._config} if group_id is not None: config_consumer["group.id"] = group_id + + # Merge extra config passed via kwargs (convert snake_case to dot.case for confluent_kafka) + for k, v in kwargs.items(): + k_dot = k.replace("_", ".") + config_consumer[k_dot] = v + c = kafka.Consumer(config_consumer) + self._consumer = c c.subscribe(topics_subscribe) - while True: - msg = c.poll(poll_timeout) - - if msg is None: - continue - elif msg.error(): - code = msg.error().code() - if code == kafka.KafkaError._PARTITION_EOF: - pass - LOG.error("Kafka consumer: {}".format(msg.error())) - continue - - data = msg.value().decode("utf-8") - LOG.debug("Received message: {}".format(data)) - yield data - - c.close() + try: + while True: + msg = c.poll(poll_timeout) + + if msg is None: + continue + elif msg.error(): + code = msg.error().code() + if code == kafka.KafkaError._PARTITION_EOF: + pass + LOG.error("Kafka consumer: {}".format(msg.error())) + continue + + val = msg.value() + if val is not None: + try: + val = val.decode("utf-8") + except Exception: + pass + + key = msg.key() + if key is not None: + try: + key = key.decode("utf-8") + except Exception: + pass + + raw_headers = msg.headers() + headers = None + if raw_headers: + headers = [] + for k, v in raw_headers: + if isinstance(v, bytes): + try: + v = v.decode("utf-8") + except Exception: + pass + headers.append((k, v)) + + consumed_msg = ConsumedMessage( + topic=msg.topic(), + partition=msg.partition(), + offset=msg.offset(), + key=key, + value=val, + headers=headers, + ) + LOG.debug("Received message: 
{}".format(consumed_msg)) + yield consumed_msg + finally: + c.close() + self._consumer = None + + def commit(self, *args, **kwargs): + """Commit offsets for the current consumer.""" + if self._consumer is not None: + self._consumer.commit(*args, **kwargs) + + def close(self): + """Close all Kafka clients.""" + if self._producer is not None: + self._producer.flush() + self._producer = None + if self._consumer is not None: + self._consumer.close() + self._consumer = None + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() diff --git a/pkg/aloha/db/kafka_aio.py b/pkg/aloha/db/kafka_aio.py new file mode 100644 index 0000000..86e6de8 --- /dev/null +++ b/pkg/aloha/db/kafka_aio.py @@ -0,0 +1,238 @@ +"""Async Kafka connection helpers.""" + +import inspect +import json +import typing +from dataclasses import dataclass + +import aiokafka as kafka +import aiokafka.admin as kafka_admin + +from ..logger import LOG + +__all__ = ("KafkaOperator", "ConsumedMessage") + +LOG.debug("kafka_aio: using aiokafka for async Kafka support") + + +@dataclass +class ConsumedMessage: + """Represents a message consumed from Kafka.""" + + topic: str + partition: int + offset: int + key: str | bytes | None + value: str | bytes | None + headers: list[tuple[str, str | bytes]] | None = None + + +class DummyMessage: + def __init__(self, topic: str, partition: int): + self._topic = topic + self._partition = partition + + def topic(self) -> str: + return self._topic + + def partition(self) -> int: + return self._partition + + +def _unpack_message(data: typing.Any) -> typing.Tuple[typing.Any, typing.Any, typing.Any]: + """Unpack data into (value, key, headers).""" + if isinstance(data, dict): + value = data.get("value") + key = data.get("key") + headers = data.get("headers") + elif hasattr(data, "value"): + value = data.value + key = getattr(data, "key", None) + headers = getattr(data, 
"headers", None) + else: + value = data + key = None + headers = None + return value, key, headers + + +def _prepare_headers(headers: typing.Any) -> typing.List[typing.Tuple[str, bytes]] | None: + if not headers: + return None + if isinstance(headers, dict): + headers = list(headers.items()) + processed = [] + for k, v in headers: + if isinstance(v, str): + v = v.encode("utf-8") + elif v is None: + v = b"" + processed.append((k, v)) + return processed + + +class KafkaOperator: + """Create async Kafka admin, producer, and consumer clients.""" + + def __init__(self, kafka_config): + """ + Parameter reference: https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md + + :param kafka_config: + host = [ + {host: kafka_server_1, port: 9092} + ] + """ + self._config = json.loads(json.dumps(kafka_config, ensure_ascii=False)) + + if "host" in kafka_config: + self._config = { + "bootstrap_servers": ",".join(["{host}:{port}".format(**i) for i in kafka_config.pop("host")]), + } + LOG.debug("Kafka (async) connection info: " + str(self._config)) + + self._admin_client = None + self._producer = None + self._consumer = None + + async def admin_client(self, *args, **kwargs) -> kafka_admin.AIOKafkaAdminClient: + """Return a configured async Kafka AdminClient.""" + if self._admin_client is None: + self._admin_client = kafka_admin.AIOKafkaAdminClient(bootstrap_servers=self._config.get("bootstrap_servers")) + await self._admin_client.start() + return self._admin_client + + async def create_topic(self, topic: str, num_partitions=3, replication_factor=1, *args, **kwargs): + """Create a Kafka topic and wait for the broker response asynchronously.""" + admin = await self.admin_client() + try: + new_topic = kafka_admin.NewTopic(topic, num_partitions=num_partitions, replication_factor=replication_factor) + await admin.create_topics([new_topic]) + LOG.info("Topic {} created".format(topic)) + return True + except Exception as e: + LOG.error("Failed to create topic {}: 
{}".format(topic, e)) + return False + + async def producer(self) -> kafka.AIOKafkaProducer: + """Return a configured async Kafka Producer.""" + if self._producer is None: + self._producer = kafka.AIOKafkaProducer( + bootstrap_servers=self._config.get("bootstrap_servers"), + value_serializer=lambda v: v.encode("utf-8") if isinstance(v, str) else v, + key_serializer=lambda k: k.encode("utf-8") if isinstance(k, str) else k, + ) + await self._producer.start() + return self._producer + + async def producer_deliver( + self, topic: str, generator: typing.AsyncIterator[typing.Any], func_callback=None, *args, **kwargs + ): + """Stream messages from an async iterator into a Kafka topic.""" + producer = await self.producer() + + if func_callback is None: + + async def delivery_report(err, msg): + """Called once for each message produced to indicate delivery result.""" + if err is not None: + LOG.error("Kafka msg delivery failed: {}".format(err)) + else: + LOG.debug("Kafka msg delivered to {} [{}]".format(msg.topic(), msg.partition())) + + func_callback = delivery_report + + async for data in generator: + value, key, headers = _unpack_message(data) + prepared_headers = _prepare_headers(headers) + try: + metadata = await producer.send_and_wait( + topic, + value=value, + key=key, + headers=prepared_headers, + ) + if func_callback is not None: + if inspect.iscoroutinefunction(func_callback): + await func_callback(None, DummyMessage(topic, metadata.partition)) + else: + func_callback(None, DummyMessage(topic, metadata.partition)) + except Exception as e: + if func_callback is not None: + if inspect.iscoroutinefunction(func_callback): + await func_callback(e, None) + else: + func_callback(e, None) + else: + LOG.error("Kafka msg delivery failed: {}".format(e)) + + async def consumer_generator( + self, topics_subscribe: list, group_id: str | None = None, poll_timeout: float = 1.0, *args, **kwargs + ) -> typing.AsyncIterator[ConsumedMessage]: + """Yield decoded messages from the 
subscribed Kafka topics asynchronously.""" + # Enable manual commit by default (At-least-once) + kwargs.setdefault("enable_auto_commit", False) + + consumer = kafka.AIOKafkaConsumer( + *topics_subscribe, + bootstrap_servers=self._config.get("bootstrap_servers"), + group_id=group_id, + auto_offset_reset="earliest", + value_deserializer=lambda v: v.decode("utf-8") if v else None, + key_deserializer=lambda k: k.decode("utf-8") if k else None, + **kwargs, + ) + self._consumer = consumer + await consumer.start() + try: + async for msg in consumer: + raw_headers = msg.headers + headers = None + if raw_headers: + headers = [] + for k, v in raw_headers: + if isinstance(v, bytes): + try: + v = v.decode("utf-8") + except Exception: + pass + headers.append((k, v)) + + consumed_msg = ConsumedMessage( + topic=msg.topic, + partition=msg.partition, + offset=msg.offset, + key=msg.key, + value=msg.value, + headers=headers, + ) + LOG.debug("Received message: {}".format(consumed_msg)) + yield consumed_msg + finally: + await consumer.stop() + self._consumer = None + + async def commit(self, *args, **kwargs): + """Commit offsets for the current consumer.""" + if self._consumer is not None: + await self._consumer.commit(*args, **kwargs) + + async def close(self): + """Close all Kafka clients.""" + if self._producer: + await self._producer.stop() + self._producer = None + if self._consumer: + await self._consumer.stop() + self._consumer = None + if self._admin_client: + await self._admin_client.close() + self._admin_client = None + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/pkg/aloha/db/mongo_aio.py b/pkg/aloha/db/mongo_aio.py new file mode 100644 index 0000000..823eb97 --- /dev/null +++ b/pkg/aloha/db/mongo_aio.py @@ -0,0 +1,257 @@ +""" +Async MongoDB connection helpers. 
class _MongoDBOperation:
    """Async MongoDB collection helper built on top of motor (async pymongo)."""

    def __init__(self, config, db_name=None, collection_name=None):
        """Create an async MongoClient and optionally bind a default collection.

        Args:
            config: Connection settings. ``host`` and ``username`` are required;
                ``port``, ``replicaSet``, ``password``, ``maxPoolSize``,
                ``authSource``, ``vault_type`` and ``vault_config`` are optional.
                ``host`` may be a list of ``{"host": ..., "port": ...}`` dicts
                when no top-level ``port`` is given.
            db_name: Database to bind; also the default ``authSource``.
            collection_name: Optional default collection.

        Errors raised while creating the client or binding the database are
        logged, not re-raised; attributes that could not be set stay ``None``.
        """
        self.db_name, self.collection_name = db_name, collection_name
        # Defined up front so close() and the other helpers see well-defined
        # attributes even when client creation below fails.
        self.conn = None
        self.db = None
        self.collection = None

        host = config["host"]

        if config.get("port") is None and isinstance(host, list):
            hosts = ["{host}:{port}".format(**h) for h in host]
        else:
            hosts = ["{host}:{port}".format(host=host, port=config.get("port", 27017))]

        replicaSet = config.get("replicaSet")
        if replicaSet is None and not _is_ip_addr(hosts[0].split(":")[0]):
            # NOTE(review): for a host name without a dot this keeps the ":port"
            # suffix in the derived replica-set name - confirm this is intended.
            replicaSet = hosts[0].split(".")[0]

        password_vault = PasswordVault.get_vault_sync(config.get("vault_type"), config.get("vault_config"))
        _config = {
            "host": "mongodb://%s" % ",".join(hosts),
            "port": config.get("port"),
            "replicaSet": replicaSet,
            "username": config["username"],
            "password": password_vault.get_password(config.get("password")),
            "maxPoolSize": config.get("maxPoolSize"),
            "authSource": config.get("authSource", db_name),
        }
        # Never write the real password to the log.
        msg = {k: ("***" if k == "password" else v) for k, v in _config.items()}
        LOG.debug(msg)

        try:
            self.conn: AsyncIOMotorClient = AsyncIOMotorClient(**_config)

            self.db = self.conn[db_name]
            if self.collection_name is not None:
                self.collection = self.db[self.collection_name]
        except Exception as e:
            LOG.exception(e)

    async def set_collection(self, collection_name):
        """Switch the active collection after verifying it exists.

        Returns:
            ``True`` once the collection has been switched.

        Raises:
            Exception: If ``collection_name`` is not present in the bound database.
        """
        if collection_name not in await self.db.list_collection_names():
            # Report the collection that was asked for, not the one currently bound.
            raise Exception("Collection[%s] does not exist in [%s]" % (collection_name, self.db_name))
        self.collection_name = collection_name
        self.collection = self.db[self.collection_name]
        return True

    async def check_and_get_collection(self, collection_name=None, raise_if_not_exists=True):
        """Return the active collection, switching it when requested.

        Args:
            collection_name: Collection to switch to; ``None`` keeps the current one.
            raise_if_not_exists: When true, refuse to switch to a collection that
                is not listed in the database.

        Raises:
            Exception: If the requested collection is missing and
                ``raise_if_not_exists`` is true.
        """
        self.db = self.conn[self.db_name]

        if self.collection_name is not None:
            self.collection = self.db[self.collection_name]

        if collection_name is not None and collection_name != self.collection_name:
            # The existence check must look at the requested collection; checking
            # the currently bound one never guards the switch.
            if raise_if_not_exists and collection_name not in await self.db.list_collection_names():
                raise Exception("Collection [%s] does not exist in [%s]" % (collection_name, self.db_name))

            self.collection_name = collection_name
            self.collection = self.db[self.collection_name]

        return self.collection

    async def insert(self, doc_or_docs, check_keys=False, collection_name=None):
        """Insert a single document or a list of documents asynchronously.

        ``check_keys`` is kept for signature compatibility only: it is not
        forwarded, because ``insert_many`` does not accept it.

        Returns:
            The driver's insert result, or ``None`` when the operation failed
            (the error is logged).
        """
        try:
            collection = await self.check_and_get_collection(collection_name)
            if isinstance(doc_or_docs, list):
                return await collection.insert_many(doc_or_docs)
            return await collection.insert_one(doc_or_docs)
        except Exception as e:
            LOG.exception(e)

    async def insert_many(self, docs, collection_name=None):
        """Insert many documents at once; returns ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.insert_many(docs)
        except Exception as e:
            LOG.exception(e)

    async def insert_one(self, doc, collection_name=None):
        """Insert one document; returns ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.insert_one(doc)
        except Exception as e:
            LOG.exception(e)

    async def delete_many(self, field_filter, collection_name=None):
        """Delete all documents matching the filter; returns ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.delete_many(filter=field_filter)
        except Exception as e:
            LOG.exception(e)

    async def delete_one(self, field_filter, collection_name=None):
        """Delete one document matching the filter; returns ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.delete_one(filter=field_filter)
        except Exception as e:
            LOG.exception(e)

    async def update_one(
        self,
        field_filter,
        update,
        upsert=False,
        bypass_document_validation=False,
        collation=None,
        array_filters=None,
        session=None,
        collection_name=None,
    ):
        """Update one document and return whether the update call succeeded.

        Returns:
            ``True`` when the driver call completed without raising, ``False``
            (after logging the error) otherwise.
        """
        try:
            collection = await self.check_and_get_collection(collection_name)
            await collection.update_one(
                filter=field_filter,
                update=update,
                upsert=upsert,
                bypass_document_validation=bypass_document_validation,
                collation=collation,
                array_filters=array_filters,
                session=session,
            )
            return True
        except Exception as e:
            LOG.exception(e)
            return False

    async def update_many(
        self,
        field_filter,
        update,
        upsert=False,
        bypass_document_validation=False,
        collation=None,
        array_filters=None,
        session=None,
        collection_name=None,
    ):
        """Update many documents matching the filter.

        Returns:
            The driver's update result, or ``None`` (after logging) on failure.
        """
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.update_many(
                filter=field_filter,
                update=update,
                upsert=upsert,
                bypass_document_validation=bypass_document_validation,
                collation=collation,
                array_filters=array_filters,
                session=session,
            )
        except Exception as e:
            LOG.exception(e)

    async def query(self, field_filter=None, sort=None, limit=40, skip=0, collection_name=None):
        """Query documents with optional sorting, limit, and skip.

        Returns:
            A list of at most ``limit`` documents, or ``None`` (after logging)
            on failure.
        """
        try:
            collection = await self.check_and_get_collection(collection_name)
            cursor = collection.find(field_filter or {})
            if sort:
                cursor = cursor.sort(sort)
            cursor = cursor.skip(skip).limit(limit)
            return await cursor.to_list(length=limit)
        except Exception as e:
            LOG.exception(e)

    async def find_many(self, field_filter=None, projection=None, collection_name=None, *args, **kwargs):
        """Run a MongoDB query and return all matching documents as a list.

        Extra positional and keyword arguments are forwarded to ``find``.
        Returns ``None`` (after logging) on failure.
        """
        try:
            collection = await self.check_and_get_collection(collection_name)
            cursor = collection.find(field_filter or {}, projection, *args, **kwargs)
            return await cursor.to_list(length=None)
        except Exception as e:
            LOG.exception(e)

    async def find_one(self, field_filter=None, projection=None, collection_name=None, *args, **kwargs):
        """Return a single matching MongoDB document; ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.find_one(field_filter or {}, projection, *args, **kwargs)
        except Exception as e:
            LOG.exception(e)

    async def count(self, field_filter=None, collection_name=None):
        """Count documents matching the filter; returns ``None`` (after logging) on failure."""
        try:
            collection = await self.check_and_get_collection(collection_name)
            return await collection.count_documents(field_filter or {})
        except Exception as e:
            LOG.exception(e)

    async def check_connected(self):
        """Ping the server.

        Raises:
            NameError: If the ping fails; the original error is attached as the cause.
        """
        try:
            await self.conn.admin.command("ping")
        except Exception as e:
            raise NameError("MongoDB: not connected") from e

    async def close(self):
        """Close the MongoDB connection, if one was created."""
        if self.conn:
            self.conn.close()
async def execute_query_scalars(self, sql, *args, **kwargs):
    """Run a SQL statement and yield its scalar results one at a time.

    This is an async generator: the connection stays open while the caller
    iterates and is released when iteration ends. Extra positional and
    keyword arguments are forwarded to ``stream_scalars``.
    """
    async with self.engine.connect() as connection:
        scalar_stream = await connection.stream_scalars(text(sql), *args, **kwargs)
        async for scalar in scalar_stream:
            yield scalar
class OracledbOperator:
    """Create and use an async SQLAlchemy-backed Oracle connection."""

    def __init__(self, db_config, **kwargs):
        """Build an async Oracle engine from the provided config.

        Example ``db_config``::

            {
                "host": "192.168.1.100",
                "port": 1521,
                "user": "PT_INDEX",
                "password": "vault_key_or_plain",
                "service_name": "orcl",   # preferred
                "sid": "orcl",            # alternative to service_name
                "vault_type": "...",
                "vault_config": {...},
                "lib_dir": "/opt/oracle/instantclient",  # optional, enables THICK mode
            }

        Args:
            db_config: Connection settings as shown above.
            **kwargs: Extra keyword arguments forwarded to ``create_async_engine``.

        Raises:
            ValueError: If neither ``service_name`` nor ``sid`` is configured.
            RuntimeError: If the Oracle client cannot be initialised or the
                engine cannot be created.
        """
        from urllib.parse import quote

        password_vault = PasswordVault.get_vault_sync(db_config.get("vault_type"), db_config.get("vault_config"))
        self._config = {
            "host": db_config["host"],
            "port": db_config["port"],
            "user": db_config["user"],
            "password": password_vault.get_password(db_config.get("password")),
        }

        if "lib_dir" in db_config:
            # NOTE(review): python-oracledb's asyncio API may require thin mode - confirm
            # that THICK mode is usable together with the async engine.
            try:
                oracledb.init_oracle_client(lib_dir=db_config["lib_dir"])
                LOG.info("Oracle client initialized in THICK mode from: %s" % db_config["lib_dir"])
            except Exception as e:
                LOG.warning(f"Warning: {e}")
                raise RuntimeError(f"Failed to initialize Oracle client: {e}") from e

        service_name = db_config.get("service_name")
        sid = db_config.get("sid")

        if service_name:
            dsn = oracledb.makedsn(db_config["host"], db_config["port"], service_name=service_name)
        elif sid:
            dsn = oracledb.makedsn(db_config["host"], db_config["port"], sid=sid)
        else:
            raise ValueError("Oracle config must specify service_name or sid")

        self._config["dsn"] = dsn

        # Percent-encode the credentials so characters such as "@", ":" or "/"
        # cannot break URL parsing.
        url = "oracle+oracledb://{user}:{password}@".format(
            user=quote(str(self._config["user"]), safe=""),
            password=quote(str(self._config["password"]), safe=""),
        )
        try:
            self.engine: AsyncEngine = create_async_engine(
                url,
                connect_args={"dsn": dsn},
                pool_size=20,
                max_overflow=10,
                pool_pre_ping=True,
                **kwargs,
            )
            LOG.debug("OracleDB (async) connected: {host}:{port}".format(**self._config))
        except Exception as e:
            LOG.error(e)
            raise RuntimeError("Failed to connect to OracleDB (async)") from e

    @property
    def connection(self):
        """Return the underlying async engine."""
        return self.engine

    async def execute_query(self, sql, *args, **kwargs):
        """Execute a SQL statement asynchronously and return the cursor result.

        The statement runs inside a transaction that is committed when it
        succeeds, so INSERT/UPDATE/DELETE statements are persisted.
        """
        async with self.engine.begin() as conn:
            cur = await conn.execute(text(sql), *args, **kwargs)
            return cur

    async def execute_query_scalars(self, sql, *args, **kwargs):
        """Execute a SQL statement and yield its scalar results one at a time."""
        async with self.engine.connect() as conn:
            cur = await conn.stream_scalars(text(sql), *args, **kwargs)
            async for row in cur:
                yield row

    @property
    def connection_str(self) -> str:
        """Return a human-readable connection string (without the password)."""
        return "oracle://{user}@{host}:{port} (async)".format(**self._config)

    async def close(self):
        """Close the async engine and all connections."""
        await self.engine.dispose()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
class PostgresOperator:
    """Create and use an async SQLAlchemy-backed PostgreSQL connection."""

    def __init__(self, db_config, **kwargs):
        """Build an async PostgreSQL engine from the database config.

        Args:
            db_config: Mapping with ``host``, ``port``, ``user`` and ``dbname``;
                optional ``password``, ``schema``, ``vault_type`` and
                ``vault_config``. When ``schema`` is present it becomes the
                session ``search_path``.
            **kwargs: Extra keyword arguments forwarded to ``create_async_engine``.

        Raises:
            RuntimeError: If the engine cannot be created.
        """
        from urllib.parse import quote

        password_vault = PasswordVault.get_vault_sync(db_config.get("vault_type"), db_config.get("vault_config"))
        self._config = {
            "host": db_config["host"],
            "port": db_config["port"],
            "user": db_config["user"],
            "password": password_vault.get_password(db_config.get("password")),
            "dbname": db_config["dbname"],
        }
        connect_args = {}
        if "schema" in db_config:
            # asyncpg has no libpq-style ``options`` argument; run-time parameters
            # such as search_path are passed through ``server_settings``.
            connect_args["server_settings"] = {"search_path": str(db_config["schema"])}

        # Percent-encode the credentials so characters such as "@", ":" or "/"
        # cannot break URL parsing.
        url_parts = dict(
            self._config,
            user=quote(str(self._config["user"]), safe=""),
            password=quote(str(self._config["password"]), safe=""),
        )
        try:
            # ``client_encoding`` is a psycopg2 dialect argument and is not passed
            # here; the asyncpg dialect does not take it.
            self.engine: AsyncEngine = create_async_engine(
                "postgresql+asyncpg://{user}:{password}@{host}:{port}/{dbname}".format(**url_parts),
                connect_args=connect_args,
                pool_size=20,
                max_overflow=10,
                pool_pre_ping=True,
                **kwargs,
            )
            LOG.debug("PostgresSQL (async) connected: {host}:{port}/{dbname}".format(**self._config))
        except Exception as e:
            LOG.error(e)
            raise RuntimeError("Failed to connect to PostgresSQL (async)") from e

    @property
    def connection(self):
        """Return the underlying async engine."""
        return self.engine

    async def execute_query(self, sql, *args, **kwargs):
        """Execute a SQL statement asynchronously and return the cursor result.

        The statement runs inside a transaction that is committed when it
        succeeds, so INSERT/UPDATE/DELETE statements are persisted.
        """
        async with self.engine.begin() as conn:
            cur = await conn.execute(text(sql), *args, **kwargs)
            return cur

    async def execute_query_scalars(self, sql, *args, **kwargs):
        """Execute a SQL statement and yield its scalar results one at a time."""
        async with self.engine.connect() as conn:
            cur = await conn.stream_scalars(text(sql), *args, **kwargs)
            async for row in cur:
                yield row

    @property
    def connection_str(self) -> str:
        """Return a human-readable connection string.

        The string contains the plain password; do not write it to logs.
        """
        return "postgresql://{user}:{password}@{host}:{port}/{dbname}".format(**self._config)

    async def close(self):
        """Close the async engine and all connections."""
        await self.engine.dispose()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
class RedisOperator:
    """Create async Redis connections with version-checked redis-py."""

    def __init__(self, config):
        """Normalize Redis connection settings and build connection metadata.

        Args:
            config: Mapping with ``host`` (required) and optional ``port``,
                ``password``, ``decode_responses``, ``max_connections``, ``db``,
                ``vault_type`` and ``vault_config``.

        Raises:
            ImportError: If the installed redis-py is older than required.
        """
        self._check_redis_version()

        password_vault = PasswordVault.get_vault_sync(config.get("vault_type"), config.get("vault_config"))
        _config = {
            "host": config["host"],
            "port": config.get("port", "6379"),
            "password": password_vault.get_password(config.get("password", None)),
            "decode_responses": config.get("decode_responses", True),
            "retry_on_timeout": True,
            "max_connections": config.get("max_connections", 1000),
            "socket_timeout": 3,
            "socket_connect_timeout": 1,
        }
        if "db" in config:
            _config["db"] = config["db"]
        self._config = _config

        # Created lazily by connection_generic and shared by every client it returns.
        self._pool = None

    @staticmethod
    def _check_redis_version() -> bool:
        """Ensure a redis-py version new enough for the helpers is installed.

        Returns:
            ``True`` when the installed version is at least 4.1.0.

        Raises:
            ImportError: If the version is too old or cannot be determined.
        """
        ver_min = version.parse("4.1.0")
        valid = False
        try:
            ver_cur = version.parse(redis.__version__)
            if ver_cur >= ver_min:
                valid = True
                LOG.debug("Using redis (async) version = %s" % redis.__version__)
        except Exception as e:
            LOG.error("Failed to obtain redis version!")
            LOG.error(str(e))

        if not valid:
            msg = "Invalid version of `redis-py`, version >=4.1.0 required for async support!"
            LOG.fatal(msg)
            raise ImportError(msg)

        return valid

    @property
    def connection_generic(self):
        """Return a standard async Redis client backed by a shared pool."""
        LOG.debug("AsyncRedis connection info: {host}:{port}".format(**self._config))

        if self._pool is None:
            # The pool owns the connection settings: a client that is handed an
            # explicit connection_pool does not apply host/port/password itself,
            # so they must be given to the pool.
            self._pool = redis.ConnectionPool(**self._config)
        return redis.Redis(connection_pool=self._pool)

    @property
    def connection_cluster(self):
        """Return a new async Redis Cluster client built from the stored settings."""
        LOG.debug("AsyncRedisCluster connection info: {host}:{port}".format(**self._config))
        return redis.RedisCluster(**self._config)

    async def close(self):
        """Disconnect the shared pool used by ``connection_generic``.

        Cluster clients returned by ``connection_cluster`` are not tracked here
        and must be closed by the caller.
        """
        if self._pool is not None:
            await self._pool.disconnect()
            self._pool = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
class SqliteOperator:
    """Create and use an async SQLAlchemy-backed SQLite connection."""

    def __init__(self, db_config, **kwargs):
        """Build an async SQLite or SQLCipher engine from the provided config.

        Args:
            db_config: Mapping with optional ``dbname`` (an empty or missing
                name yields an in-memory URL) and optional ``password``,
                ``vault_type`` and ``vault_config`` for an encrypted database.
            **kwargs: Extra keyword arguments forwarded to ``create_async_engine``.

        Raises:
            RuntimeError: If ``sqlcipher3`` is required but not installed, or
                the engine cannot be created.
        """
        self._connection_pattern = "sqlite+aiosqlite://{dbname}"
        dbname = db_config.get("dbname", "")
        if len(dbname) > 0:
            dbname = "/%s" % dbname
        self._config = {"dbname": dbname}

        if "password" in db_config:
            try:
                import sqlcipher3
            except ImportError as e:
                raise RuntimeError("Python package required for encrypted sqlite3: sqlcipher3-binary") from e
            LOG.debug("Version of sqlcipher3 = %s" % sqlcipher3.sqlite_version)
            password_vault = PasswordVault.get_vault_sync(db_config.get("vault_type"), db_config.get("vault_config"))
            password = password_vault.get_password(db_config.get("password", None))
            self._config["password"] = password
            # NOTE(review): pysqlcipher looks like a synchronous driver - confirm that
            # create_async_engine accepts this URL.
            self._connection_pattern = "sqlite+pysqlcipher://:{password}@/{dbname}"
        else:
            LOG.debug("Version of sqlite = %s" % sqlite3.sqlite_version)

        try:
            self.engine: AsyncEngine = create_async_engine(self._connection_pattern.format(**self._config), **kwargs)
            # Log the URL with the passphrase masked; connection_str still returns the real one.
            masked = dict(self._config)
            if "password" in masked:
                masked["password"] = "***"
            LOG.debug("Sqlite (async) connected: %s" % self._connection_pattern.format(**masked))
        except Exception as e:
            LOG.exception(e)
            raise RuntimeError("Failed to connect to sqlite (async)") from e

    @property
    def connection(self):
        """Return the underlying async engine."""
        return self.engine

    async def execute_query(self, sql, *args, **kwargs):
        """Execute a SQL statement asynchronously and return the cursor result.

        The statement runs inside a transaction that is committed when it
        succeeds, so INSERT/UPDATE/DELETE statements are persisted.
        """
        async with self.engine.begin() as conn:
            cur = await conn.execute(text(sql), *args, **kwargs)
            return cur

    async def execute_query_scalars(self, sql, *args, **kwargs):
        """Execute a SQL statement and yield its scalar results one at a time."""
        async with self.engine.connect() as conn:
            cur = await conn.stream_scalars(text(sql), *args, **kwargs)
            async for row in cur:
                yield row

    @property
    def connection_str(self) -> str:
        """Return the SQLAlchemy connection URL used by the engine.

        For an encrypted database the URL contains the plain passphrase.
        """
        return self._connection_pattern.format(**self._config)

    async def close(self):
        """Close the async engine and all connections."""
        await self.engine.dispose()

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()
ConsumedMessage, _unpack_message, _prepare_headers +from aloha.db.kafka_aio import KafkaOperator as KafkaOperatorAio, ConsumedMessage as ConsumedMessageAio + + +class TestKafkaHelpers(UnitTestCase): + def test_unpack_message(self): + # 1. string input + val, key, headers = _unpack_message("hello") + self.assertEqual(val, "hello") + self.assertIsNone(key) + self.assertIsNone(headers) + + # 2. dict input + val, key, headers = _unpack_message({"value": "hello", "key": "k", "headers": {"h": "v"}}) + self.assertEqual(val, "hello") + self.assertEqual(key, "k") + self.assertEqual(headers, {"h": "v"}) + + # 3. object/ConsumedMessage input + msg = ConsumedMessage(topic="t", partition=0, offset=0, key="k", value="hello", headers=[("h", "v")]) + val, key, headers = _unpack_message(msg) + self.assertEqual(val, "hello") + self.assertEqual(key, "k") + self.assertEqual(headers, [("h", "v")]) + + def test_prepare_headers(self): + # dict headers + hdrs = _prepare_headers({"tenant_id": "123", "trace_id": "abc"}) + self.assertEqual(hdrs, [("tenant_id", b"123"), ("trace_id", b"abc")]) + + # list headers with string and bytes values + hdrs = _prepare_headers([("tenant_id", "123"), ("trace_id", b"abc"), ("empty", None)]) + self.assertEqual(hdrs, [("tenant_id", b"123"), ("trace_id", b"abc"), ("empty", b"")]) + + +class TestSyncKafkaOperator(UnitTestCase): + @patch("confluent_kafka.Producer") + def test_producer_singleton_and_deliver(self, mock_producer_cls): + mock_producer = MagicMock() + mock_producer_cls.return_value = mock_producer + + config = {"host": [{"host": "localhost", "port": 9092}]} + op = KafkaOperator(config) + + # Producer is lazily created + p1 = op.producer() + p2 = op.producer() + self.assertIs(p1, p2) + mock_producer_cls.assert_called_once() + + # Test producer_deliver + generator = ["msg1", {"value": "msg2", "key": "k2", "headers": {"h2": "v2"}}] + op.producer_deliver("my_topic", generator) + + # Check mock calls + self.assertEqual(mock_producer.produce.call_count, 
2) + # First call: string only + mock_producer.produce.assert_any_call( + "my_topic", value=b"msg1", key=None, headers=None, callback=mock_producer.produce.call_args_list[0][1]["callback"] + ) + # Second call: dict with key and headers + mock_producer.produce.assert_any_call( + "my_topic", value=b"msg2", key="k2", headers=[("h2", b"v2")], callback=mock_producer.produce.call_args_list[1][1]["callback"] + ) + mock_producer.flush.assert_called_once() + + # Context manager exit calls close() and flushes/cleans up producer + op.close() + self.assertIsNone(op._producer) + + @patch("confluent_kafka.Consumer") + def test_consumer_generator_and_commit(self, mock_consumer_cls): + mock_consumer = MagicMock() + mock_consumer_cls.return_value = mock_consumer + + # Mocking poll to return one message, then raise exception to exit + mock_msg = MagicMock() + mock_msg.topic.return_value = "my_topic" + mock_msg.partition.return_value = 0 + mock_msg.offset.return_value = 100 + mock_msg.key.return_value = b"my_key" + mock_msg.value.return_value = b"my_value" + mock_msg.headers.return_value = [("h", b"v")] + mock_msg.error.return_value = None + + mock_consumer.poll.side_effect = [mock_msg, KeyboardInterrupt("stop")] + + config = {"host": [{"host": "localhost", "port": 9092}]} + op = KafkaOperator(config) + + gen = op.consumer_generator(["my_topic"], group_id="my_group") + try: + msg = next(gen) + self.assertIsInstance(msg, ConsumedMessage) + self.assertEqual(msg.topic, "my_topic") + self.assertEqual(msg.partition, 0) + self.assertEqual(msg.offset, 100) + self.assertEqual(msg.key, "my_key") + self.assertEqual(msg.value, "my_value") + self.assertEqual(msg.headers, [("h", "v")]) + + # Commit calls mock consumer commit + op.commit() + mock_consumer.commit.assert_called_once() + + # Continue generator to trigger KeyboardInterrupt/finally block + next(gen) + except KeyboardInterrupt: + pass + + # Ensure consumer is closed and cleared + mock_consumer.close.assert_called_once() + 
class TestAsyncKafkaOperator(UnitTestCase):
    """Tests for the aiokafka-based operator, driven by mocked aiokafka clients."""

    def test_producer_singleton_and_deliver(self):
        """The producer is created once, and every message is sent and reported."""

        async def run_test():
            with patch("aiokafka.AIOKafkaProducer") as mock_producer_cls:
                mock_producer = MagicMock()
                mock_producer.start = AsyncMock()
                mock_producer.stop = AsyncMock()
                mock_producer.send_and_wait = AsyncMock()
                mock_producer_cls.return_value = mock_producer

                config = {"host": [{"host": "localhost", "port": 9092}]}
                op = KafkaOperatorAio(config)

                # Producer singleton: a second call must reuse the started client.
                p1 = await op.producer()
                p2 = await op.producer()
                self.assertIs(p1, p2)
                mock_producer.start.assert_awaited_once()

                # Test producer_deliver
                async def mock_generator():
                    yield "msg1"
                    yield {"value": "msg2", "key": "k2", "headers": {"h2": "v2"}}

                mock_callback = AsyncMock()
                await op.producer_deliver("my_topic", mock_generator(), func_callback=mock_callback)

                self.assertEqual(mock_producer.send_and_wait.call_count, 2)
                mock_producer.send_and_wait.assert_any_call(
                    "my_topic", value="msg1", key=None, headers=None
                )
                mock_producer.send_and_wait.assert_any_call(
                    "my_topic", value="msg2", key="k2", headers=[("h2", b"v2")]
                )
                # One delivery report per message.
                self.assertEqual(mock_callback.await_count, 2)

                # Close stops producer
                await op.close()
                mock_producer.stop.assert_awaited_once()
                self.assertIsNone(op._producer)

        asyncio.run(run_test())

    def test_consumer_generator_and_commit(self):
        """A consumed record is decoded, can be committed, and the consumer is stopped."""

        async def run_test():
            with patch("aiokafka.AIOKafkaConsumer") as mock_consumer_cls:
                mock_consumer = MagicMock()
                mock_consumer.start = AsyncMock()
                mock_consumer.stop = AsyncMock()
                mock_consumer.commit = AsyncMock()

                # Mocking async iterator on consumer
                mock_msg = MagicMock()
                mock_msg.topic = "my_topic"
                mock_msg.partition = 0
                mock_msg.offset = 100
                mock_msg.key = "my_key"
                mock_msg.value = "my_value"
                mock_msg.headers = [("h", b"v")]

                async def mock_aiter(self_iter):
                    yield mock_msg

                mock_consumer.__aiter__ = mock_aiter
                mock_consumer_cls.return_value = mock_consumer

                config = {"host": [{"host": "localhost", "port": 9092}]}
                op = KafkaOperatorAio(config)

                received = []
                async for msg in op.consumer_generator(["my_topic"], group_id="my_group"):
                    received.append(msg)
                    # Commit while the consumer is still active.
                    await op.commit()
                    mock_consumer.commit.assert_awaited_once()

                # Guard against a vacuous pass: the loop body must have run exactly once.
                self.assertEqual(len(received), 1)
                msg = received[0]
                self.assertIsInstance(msg, ConsumedMessageAio)
                self.assertEqual(msg.topic, "my_topic")
                self.assertEqual(msg.partition, 0)
                self.assertEqual(msg.offset, 100)
                self.assertEqual(msg.key, "my_key")
                self.assertEqual(msg.value, "my_value")
                self.assertEqual(msg.headers, [("h", "v")])

                mock_consumer.start.assert_awaited_once()
                mock_consumer.stop.assert_awaited_once()
                self.assertIsNone(op._consumer)

        asyncio.run(run_test())