180 lines
6.4 KiB
Python
180 lines
6.4 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
import xgboost as xgb
|
|
import mysql.connector
|
|
from mysql.connector import Error
|
|
from datetime import datetime, timedelta
|
|
import holidays # NIEUW: importeer holidays
|
|
|
|
# --- CONFIGURATIE ---
|
|
# BELANGRIJK: Verwijs naar je NIEUWE model
|
|
MODEL_FILE = 'price_forecast_model_1.5.linux.json'
|
|
TARGET = 'gemiddelde_prijs'
|
|
AANTAL_UUR_VOORSPELLEN = 72 # Hoeveel uur vooruit wil je kijken?
|
|
|
|
DB_CONFIG = {
|
|
'host': '192.168.178.201',
|
|
'user': 'energy_prices_user',
|
|
'password': 'kS9R*xp17ZwCD@CV&E^N',
|
|
'database': 'energy_prices',
|
|
'port': 3307
|
|
}
|
|
|
|
# NIEUW: Haal de lijst van features uit je model
|
|
# We hoeven de lijst niet meer handmatig te typen!
|
|
try:
|
|
print(f"Laden van model: {MODEL_FILE}...")
|
|
model = xgb.XGBRegressor()
|
|
model.load_model(MODEL_FILE)
|
|
FEATURES = model.feature_names_in_ # Pakt automatisch alle feature-namen
|
|
print(f"✅ Model succesvol geladen (verwacht {len(FEATURES)} features).")
|
|
except FileNotFoundError:
|
|
print(f"❌ Fout: Model bestand '{MODEL_FILE}' niet gevonden.")
|
|
print("Heb je het 'v1_5 (linux)' model al getraind en opgeslagen?")
|
|
exit()
|
|
except Exception as e:
|
|
print(f"❌ Fout bij laden model: {e}")
|
|
exit()
|
|
|
|
# Maak de feestdagen-checker klaar
|
|
nl_holidays = holidays.Netherlands(years=[datetime.now().year, datetime.now().year + 1])
|
|
|
|
|
|
def haal_data_uit_database(conn):
|
|
# ... (Deze functie is 100% IDENTIEK aan je vorige script)
|
|
# ... (Kopieer de 'haal_data_uit_database' functie hier)
|
|
print("💾 Data ophalen uit MySQL...")
|
|
|
|
query_hist = """
|
|
SELECT
|
|
w.datum_tijd, w.temperatuur, w.gevoelstemperatuur, w.neerslag,
|
|
w.wind_richting, w.wind_snelheid, w.bewolking, w.luchtdruk, w.luchtvochtigheid,
|
|
p_avg.gemiddelde_prijs
|
|
FROM
|
|
amersfoort_weer_uurlijks AS w
|
|
LEFT JOIN -- <--- DE OPLOSSING
|
|
(SELECT datetime, AVG(price) AS gemiddelde_prijs
|
|
FROM dynamic_price_data GROUP BY datetime) AS p_avg
|
|
ON w.datum_tijd = p_avg.datetime
|
|
WHERE
|
|
w.datum_tijd BETWEEN (UTC_TIMESTAMP() - INTERVAL 30 HOUR) AND UTC_TIMESTAMP()
|
|
ORDER BY
|
|
w.datum_tijd;
|
|
"""
|
|
|
|
query_toekomst = f"""
|
|
SELECT
|
|
datum_tijd, temperatuur, gevoelstemperatuur, neerslag,
|
|
wind_richting, wind_snelheid, bewolking, luchtdruk, luchtvochtigheid,
|
|
NULL AS gemiddelde_prijs
|
|
FROM
|
|
amersfoort_weer_uurlijks
|
|
WHERE
|
|
datum_tijd BETWEEN UTC_TIMESTAMP() AND (UTC_TIMESTAMP() + INTERVAL {AANTAL_UUR_VOORSPELLEN} HOUR)
|
|
ORDER BY
|
|
datum_tijd;
|
|
"""
|
|
|
|
try:
|
|
hist_df = pd.read_sql(query_hist, conn, index_col='datum_tijd', parse_dates=['datum_tijd'])
|
|
toekomst_df = pd.read_sql(query_toekomst, conn, index_col='datum_tijd', parse_dates=['datum_tijd'])
|
|
|
|
print(f"✅ {len(hist_df)} uur historie geladen.")
|
|
print(f"✅ {len(toekomst_df)} uur toekomstig weer geladen.")
|
|
|
|
# --- OPLOSSING HIER ---
|
|
# Vul 'gaten' in de historische prijsdata ALLEEN op hist_df
|
|
hist_df['gemiddelde_prijs'] = hist_df['gemiddelde_prijs'].ffill()
|
|
hist_df['gemiddelde_prijs'] = hist_df['gemiddelde_prijs'].bfill()
|
|
# --- EINDE OPLOSSING ---
|
|
|
|
# Combineer nu de gevulde historie met de lege toekomst
|
|
combined_df = pd.concat([hist_df, toekomst_df])
|
|
|
|
return combined_df.sort_index()
|
|
|
|
except Exception as e:
|
|
print(f"❌ Fout bij ophalen data: {e}")
|
|
return None
|
|
|
|
|
|
def maak_features_voor_uur(df, timestamp):
|
|
"""
|
|
MAAK FEATURES v1.5 - Deze functie is compleet VERNIEUWD
|
|
"""
|
|
features = {}
|
|
|
|
# Haal data op van het specifieke uur
|
|
data_nu = df.loc[timestamp]
|
|
|
|
# 1. Tijd-features (simpel)
|
|
features['maand'] = timestamp.month
|
|
features['dag_van_het_jaar'] = timestamp.dayofyear
|
|
|
|
# 2. Feestdag feature
|
|
features['is_feestdag'] = 1 if timestamp in nl_holidays else 0
|
|
|
|
# 3. Weer-features
|
|
weer_cols = ['temperatuur', 'gevoelstemperatuur', 'neerslag', 'wind_richting',
|
|
'wind_snelheid', 'bewolking', 'luchtdruk', 'luchtvochtigheid']
|
|
for col in weer_cols:
|
|
features[col] = data_nu[col]
|
|
|
|
# 4. Lag-features
|
|
features['prijs_1u_geleden'] = df.loc[timestamp - timedelta(hours=1)]['gemiddelde_prijs']
|
|
features['prijs_24u_geleden'] = df.loc[timestamp - timedelta(hours=24)]['gemiddelde_prijs']
|
|
|
|
# 5. Rolling-features
|
|
features['temp_avg_3u'] = df.loc[timestamp - timedelta(hours=2) : timestamp]['temperatuur'].mean()
|
|
features['prijs_avg_6u'] = df.loc[timestamp - timedelta(hours=5) : timestamp]['gemiddelde_prijs'].mean()
|
|
|
|
# 6. ONE-HOT ENCODING (Handmatig)
|
|
# Voeg alle 7 'dag_' kolommen toe, en zet de juiste op 1
|
|
for dag in range(7):
|
|
features[f'dag_{dag}'] = 1 if timestamp.dayofweek == dag else 0
|
|
|
|
# Voeg alle 24 'uur_' kolommen toe, en zet de juiste op 1
|
|
for uur in range(24):
|
|
features[f'uur_{uur}'] = 1 if timestamp.hour == uur else 0
|
|
|
|
# Converteer naar een DataFrame en gebruik de volgorde van het model
|
|
return pd.DataFrame([features], columns=FEATURES)
|
|
|
|
|
|
# --- START VAN HET SCRIPT ---
|
|
# (Dit deel is weer 100% identiek aan je vorige script)
|
|
try:
|
|
conn = mysql.connector.connect(**DB_CONFIG)
|
|
werk_df = haal_data_uit_database(conn)
|
|
|
|
if werk_df is not None:
|
|
te_voorspellen_tijden = werk_df[werk_df['gemiddelde_prijs'].isnull()].index
|
|
|
|
print(f"\n🧠 Start iteratieve voorspelling voor {len(te_voorspellen_tijden)} uur...")
|
|
voorspellingen = []
|
|
|
|
for timestamp in te_voorspellen_tijden:
|
|
features_nu = maak_features_voor_uur(werk_df, timestamp)
|
|
voorspelde_prijs = model.predict(features_nu)[0]
|
|
werk_df.loc[timestamp, 'gemiddelde_prijs'] = voorspelde_prijs
|
|
voorspellingen.append(voorspelde_prijs)
|
|
|
|
print("\n" + "="*70)
|
|
pd.set_option('display.max_rows', None) # Zorg dat we alles printen
|
|
print(f"--- VOORSPELDE PRIJZEN (komende {len(te_voorspellen_tijden)} uur) ---")
|
|
|
|
resultaat_df = pd.DataFrame({
|
|
'Voorspelde_Prijs': voorspellingen
|
|
}, index=te_voorspellen_tijden)
|
|
|
|
print(resultaat_df)
|
|
print("="*70)
|
|
|
|
except Error as e:
|
|
print(f"❌ Fout met MySQL verbinding: {e}")
|
|
except Exception as e:
|
|
print(f"❌ Een onverwachte fout is opgetreden: {e}")
|
|
finally:
|
|
if 'conn' in locals() and conn.is_connected():
|
|
conn.close()
|
|
print("\nVerbinding met MySQL gesloten.") |