Files
crm.twinpol.com/modules/EcmInvoiceOuts/ai/worker.py
2025-08-22 15:56:47 +02:00

142 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# worker.py
import datetime as dt
import io
import json
import os
import re
import uuid
from typing import Any, Dict, List

import polars as pl
import pymysql
from dotenv import load_dotenv
from tenacity import retry, retry_if_not_exception_type, stop_after_attempt, wait_exponential
# Pull settings from a .env file (if present) into the process environment.
load_dotenv()

# Model identifier and API key used when constructing AIClient.
AI_MODEL = os.getenv("AI_MODEL", "gpt-5-pro")
AI_API_KEY = os.getenv("AI_API_KEY")  # None when the variable is unset

# Keyword arguments passed verbatim to pymysql.connect() by mysql_query().
MYSQL_CONF = dict(
    host=os.getenv("MYSQL_HOST", "localhost"),
    # Non-default ports are common for hosted MySQL; 3306 keeps the old behaviour.
    port=int(os.getenv("MYSQL_PORT", "3306")),
    user=os.getenv("MYSQL_USER", "root"),
    password=os.getenv("MYSQL_PASSWORD", ""),
    database=os.getenv("MYSQL_DB", "sales"),
    # Explicit connection charset so non-ASCII labels (e.g. Polish names) round-trip
    # intact regardless of the driver's or server's default.
    charset=os.getenv("MYSQL_CHARSET", "utf8mb4"),
    # Rows come back as dicts keyed by column name (mysql_query builds a DataFrame from them).
    cursorclass=pymysql.cursors.DictCursor,
)
def mysql_query(sql: str, params: tuple = ()) -> pl.DataFrame:
    """Execute *sql* with *params* on a fresh MySQL connection and return the rows.

    A new connection is opened from MYSQL_CONF for every call and is always
    closed, even when execution fails. Placeholders in *sql* use pymysql's
    ``%s`` style and are filled from *params*.

    Returns:
        A polars DataFrame with one row per result row. An empty result still
        carries the column names reported by the cursor, so CSV rendering keeps
        its header line.
    """
    conn = pymysql.connect(**MYSQL_CONF)
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
            # description is None when the statement produced no result set.
            columns = [col[0] for col in (cur.description or ())]
    finally:
        conn.close()
    if not rows:
        return pl.DataFrame(schema=columns)
    # Infer dtypes from every row instead of only the first 100 (the polars
    # default): a column that is NULL early on (e.g. a NULLIF(...) result) must
    # not be pinned to the Null dtype before real values appear.
    return pl.from_dicts(rows, infer_schema_length=None)
def to_csv(df: pl.DataFrame) -> str:
    """Render *df* as CSV text (header row first) and return it as a string."""
    with io.StringIO() as sink:
        df.write_csv(sink)
        text = sink.getvalue()
    return text
# Daily KPIs (revenue, quantity, gross margin %, discount %) for every day from
# from_date to to_date inclusive, one row per calendar day.
# Query parameters, in order: from_date, to_date.
# The upper bound is half-open (< to_date + 1 day) instead of BETWEEN: if
# invoice_date carries a time of day, BETWEEN ... AND 'YYYY-MM-DD' would stop at
# 00:00:00 and silently drop the rest of the last day.
SQL_KPIS_DAILY = """
SELECT DATE(invoice_date) AS d,
       SUM(net_amount) AS revenue,
       SUM(quantity) AS qty,
       ROUND(100*SUM(net_amount - cost_amount)/NULLIF(SUM(net_amount),0), 2) AS gross_margin_pct,
       ROUND(100*SUM(discount_amount)/NULLIF(SUM(gross_amount),0), 2) AS discount_pct
FROM fact_invoices
WHERE invoice_date >= %s
  AND invoice_date < DATE_ADD(%s, INTERVAL 1 DAY)
GROUP BY 1
ORDER BY 1;
"""
# Top-N segments by revenue, grouped by column {axis} and labelled with {label}.
# {axis}/{label} are spliced in with str.format (identifiers cannot be bound as
# query parameters), so callers must only pass trusted column names.
# Query parameters, in order: from_date, to_date, limit.
# Rows cover from_date minus 60 days up to and including to_date. trend_30d is the
# percentage change of a segment's revenue in the 30 days ending on to_date versus
# the 30 days before that (NULL when the earlier window has no revenue).
# The CTE binds each date parameter once so it can be reused; like the window
# functions it replaces, it needs MySQL 8.0+. `key` is back-quoted because KEY is
# a reserved word in MySQL.
SQL_TOP_SEGMENTS = """
WITH p AS (
    SELECT CAST(%s AS DATE) AS d_from,
           CAST(%s AS DATE) AS d_to
)
SELECT {axis} AS `key`,
       ANY_VALUE({label}) AS label,
       SUM(net_amount) AS revenue,
       SUM(quantity) AS qty,
       ROUND(100*SUM(net_amount - cost_amount)/NULLIF(SUM(net_amount),0), 2) AS gross_margin_pct,
       ROUND(100*(SUM(CASE WHEN invoice_date >= DATE_SUB(p.d_to, INTERVAL 29 DAY)
                           THEN net_amount ELSE 0 END)
                - SUM(CASE WHEN invoice_date >= DATE_SUB(p.d_to, INTERVAL 59 DAY)
                            AND invoice_date <  DATE_SUB(p.d_to, INTERVAL 29 DAY)
                           THEN net_amount ELSE 0 END))
             / NULLIF(SUM(CASE WHEN invoice_date >= DATE_SUB(p.d_to, INTERVAL 59 DAY)
                                AND invoice_date <  DATE_SUB(p.d_to, INTERVAL 29 DAY)
                               THEN net_amount ELSE 0 END), 0), 2) AS trend_30d
FROM fact_invoices
CROSS JOIN p
WHERE invoice_date >= DATE_SUB(p.d_from, INTERVAL 60 DAY)
  AND invoice_date <  DATE_ADD(p.d_to, INTERVAL 1 DAY)
GROUP BY 1
ORDER BY revenue DESC
LIMIT %s;
"""
class AIClient:
    """Client for the AI model API. Both calls are placeholders awaiting a real SDK.

    Args:
        api_key: Provider API key (the module-level AI_API_KEY is None when unset).
        model: Model identifier; defaults to the module-level AI_MODEL setting.
    """

    def __init__(self, api_key: str, model: str = AI_MODEL):
        self.api_key = api_key
        self.model = model

    # Failures are retried with exponential back-off (1-20 s, at most 6 attempts).
    # NotImplementedError is excluded: retrying a missing implementation only
    # delays the failure by ~31 s and then hides it inside a tenacity RetryError.
    @retry(
        wait=wait_exponential(multiplier=1, min=1, max=20),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(NotImplementedError),
    )
    def structured_analysis(self, prompt: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the model for an answer to *prompt* constrained to the JSON *schema*.

        Raises:
            NotImplementedError: always, until a model SDK is wired in.
        """
        # TODO: replace with a real model call using "Structured Outputs".
        raise NotImplementedError("Wire your model SDK here")

    @retry(
        wait=wait_exponential(multiplier=1, min=1, max=20),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(NotImplementedError),
    )
    def batch_submit(self, ndjson_lines: List[str]) -> str:
        """Submit NDJSON request lines to the provider's Batch API.

        Returns a string (presumably the batch ID - confirm when implementing).

        Raises:
            NotImplementedError: always, until the Batch API is wired in.
        """
        # TODO: replace with the real Batch API call.
        raise NotImplementedError
# A plain or table-qualified column name such as "sku_id" or "s.sku_name".
_SQL_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?")


def _require_identifier(value: str, what: str) -> str:
    """Return *value* if it is a plain SQL column name, otherwise raise ValueError."""
    if not isinstance(value, str) or _SQL_IDENTIFIER.fullmatch(value) is None:
        raise ValueError(f"{what} must be a plain SQL column name, got {value!r}")
    return value


def _build_prompt(from_date: str, to_date: str, currency: str, axis: str, goal: str, csv_blocks: str) -> str:
    """Render the Polish-language analyst prompt wrapped around the CSV data blocks."""
    return f"""
Jesteś analitykiem sprzedaży. Otrzymasz: (a) kontekst, (b) dane.
Zwróć **wyłącznie** JSON zgodny ze schema.
Kontekst:
- Waluta: {currency}
- Zakres: {from_date} → {to_date}
- Cel: {goal}
- Poziom segmentacji: {axis}
Dane (CSV):
{csv_blocks}
Wskazówki:
- Użyj danych jak są (nie wymyślaj liczb).
- W meta.scope wpisz opis zakresu i segmentacji.
- Jeśli brak anomalii → anomalies: [].
- Kwoty do 2 miejsc, procenty do 1.
"""


def _save_result(result: Dict[str, Any]) -> str:
    """Write *result* as UTF-8 JSON to a new out/<uuid4>.json beside this module; return its path."""
    out_dir = os.path.join(os.path.dirname(__file__), "out")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{uuid.uuid4()}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False)
    return out_path


def run_online(from_date: str, to_date: str, currency: str, axis: str, label: str, top_n: int, goal: str) -> Dict[str, Any]:
    """Run one synchronous sales analysis and save the model's JSON answer to disk.

    Queries daily KPIs and the top segments for the date range, sends them as CSV
    to the model together with sales-analysis.schema.json, and writes the result
    to a new file under ``out/`` next to this module.

    Args:
        from_date, to_date: Date range passed to the SQL queries.
        currency: Currency name quoted in the prompt.
        axis: Column to segment by (e.g. sku_id, client_id, region_code).
        label: Column shown as the human-readable segment label.
        top_n: Maximum number of segments to fetch (must be >= 1).
        goal: Free-text analysis goal quoted in the prompt.

    Returns:
        ``{"status": "ok", "path": <path of the written JSON file>}``.

    Raises:
        ValueError: if *axis* or *label* is not a plain SQL column name, or
            *top_n* is below 1. Checked before any database access, because
            axis/label are interpolated into the SQL text.
    """
    _require_identifier(axis, "axis")
    _require_identifier(label, "label")
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n!r}")

    kpis = mysql_query(SQL_KPIS_DAILY, (from_date, to_date))
    top = mysql_query(SQL_TOP_SEGMENTS.format(axis=axis, label=label), (from_date, to_date, top_n))
    csv_blocks = "## kpis_daily\n" + to_csv(kpis) + "\n\n" + "## top_segments\n" + to_csv(top)

    schema_path = os.path.join(os.path.dirname(__file__), "sales-analysis.schema.json")
    with open(schema_path, "r", encoding="utf-8") as f:
        schema = json.load(f)

    prompt = _build_prompt(from_date, to_date, currency, axis, goal, csv_blocks)
    result = AIClient(AI_API_KEY).structured_analysis(prompt, schema)
    return {"status": "ok", "path": _save_result(result)}
def run_batch(from_date: str, to_date: str, axis: str, label: str):
    """Batch counterpart of run_online - not implemented yet.

    Per the project blueprint this is meant to generate NDJSON request lines
    (the full variant is described in the blueprint PDF). Until then every call
    raises NotImplementedError.
    """
    # TODO: add the real batch_submit calls and persist the returned ID/state.
    message = "Implement batch per blueprint"
    raise NotImplementedError(message)
# Segmentation columns accepted on the command line (shared by both sub-commands).
_AXES = ("sku_id", "client_id", "region_code")


def main(argv=None) -> int:
    """Command-line entry point; parses *argv* (default: sys.argv[1:]) and returns an exit status.

    Sub-commands:
        online FROM TO CURRENCY AXIS LABEL [TOP_N=50] GOAL -> run_online
        batch  FROM TO AXIS LABEL                          -> run_batch

    The sub-command's result is printed as JSON so non-Python callers can parse
    it. With no sub-command the help text is printed and 2 is returned.
    """
    import argparse

    p = argparse.ArgumentParser()
    sub = p.add_subparsers(dest="cmd")

    o = sub.add_parser("online")
    o.add_argument("from_date")
    o.add_argument("to_date")
    o.add_argument("currency")
    o.add_argument("axis", choices=_AXES)
    o.add_argument("label")
    # Optional positional ahead of a required one: given six values argparse
    # assigns the last one to goal and top_n keeps its default.
    o.add_argument("top_n", type=int, nargs="?", default=50)
    o.add_argument("goal")

    b = sub.add_parser("batch")
    b.add_argument("from_date")
    b.add_argument("to_date")
    b.add_argument("axis", choices=_AXES)
    b.add_argument("label")

    args = p.parse_args(argv)
    if args.cmd == "online":
        result = run_online(args.from_date, args.to_date, args.currency, args.axis, args.label, args.top_n, args.goal)
    elif args.cmd == "batch":
        result = run_batch(args.from_date, args.to_date, args.axis, args.label)
    else:
        p.print_help()
        return 2
    print(json.dumps(result, ensure_ascii=False))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())