Merge pull request #27 from vnbl/docker-inference
Merge `docker-inference` branch into dev
Showing 32 changed files with 459 additions and 54 deletions.
@@ -0,0 +1,2 @@
src/
venv/
@@ -1,11 +1,14 @@
-POSTGRES_USER='fer'
-POSTGRES_PASSWORD='nanda'
-POSTGRES_HOST='localhost'
-POSTGRES_DATABASE='estaciones'
+POSTGRES_USER='postgres'
+POSTGRES_PASSWORD='password'
+POSTGRES_HOST='your_host'
+POSTGRES_DB='db_name'
+POSTGRES_PORT='5432'
 
-MYSQL_USER='fer'
-MYSQL_PASSWORD='nanda'
-MYSQL_HOST='localhost'
-MYSQL_DATABASE='estaciones_remote'
+MYSQL_USER='mysql_user'
+MYSQL_PASSWORD='secret_pass'
+MYSQL_HOST='mysql_host'
+MYSQL_DB='mysql_db'
+MYSQL_PORT='3306'
 
 AIRNOW_API_KEY='your_secret_airnow_api_key'
+PIPELINE_HOST='localhost'
+PIPELINE_PORT='6789'
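For context, the pipeline reads these values from the environment. A minimal sketch of how the Postgres settings could be assembled into a connection URL (the variable names simply mirror the keys above; the actual loader code is not part of this diff):

import os

# Assumed to be exported from the .env file above (e.g. via docker compose env_file).
pg_user = os.environ.get("POSTGRES_USER", "postgres")
pg_password = os.environ.get("POSTGRES_PASSWORD", "password")
pg_host = os.environ.get("POSTGRES_HOST", "your_host")
pg_port = os.environ.get("POSTGRES_PORT", "5432")
pg_db = os.environ.get("POSTGRES_DB", "db_name")

# SQLAlchemy-style URL; swap the scheme/driver for whatever the pipeline actually uses.
postgres_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}"
print(postgres_url)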
@@ -24,3 +24,6 @@ tmp*
 
 # mage drafts
 etl-pipeline/drafts/
+
+#mage
+.mage_data/
@@ -0,0 +1,36 @@
ARG PYTHON_VERSION=3.10-slim-buster

FROM python:${PYTHON_VERSION}

ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Set the working directory
WORKDIR /app

# Install libgomp1 (the OpenMP runtime needed by LightGBM) and clean up apt lists
RUN apt-get update && \
    apt-get install -y libgomp1 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Copy only requirements.txt and the entrypoint script to the container
COPY requirements.txt /app/
COPY run_app.sh /app/

# Install dependencies and clean up cache
RUN set -ex && \
    pip install --upgrade pip && \
    pip install -r /app/requirements.txt && \
    rm -rf /root/.cache/

# Copy the rest of the application files into the container
COPY . /app/

# Ensure run_app.sh is executable
RUN chmod +x /app/run_app.sh

# Expose the port for the app
EXPOSE 6789

# Define the default command
CMD ["/app/run_app.sh"]
@@ -0,0 +1,54 @@
services:
  db:
    image: postgres:16.4-alpine3.20
    restart: always
    container_name: db_aire
    env_file:
      - .env
    environment:
      POSTGRES_USER: ${POSTGRES_USER}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
      POSTGRES_DB: ${POSTGRES_DB}
    ports:
      - "${POSTGRES_PORT}:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
      interval: 10s
      timeout: 5s
      retries: 5

  pipeline:
    build:
      context: .
      dockerfile: Dockerfile
    restart: always
    environment:
      MAGE_DATA_DIR: ../.mage_data
      PIPELINE_POSTGRES_USER: ${POSTGRES_USER}
      PIPELINE_POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
      PIPELINE_POSTGRES_DB: ${POSTGRES_DB}
      PIPELINE_POSTGRES_HOST: db
      PIPELINE_POSTGRES_PORT: 5432
      PIPELINE_MYSQL_USER: ${MYSQL_USER}
      PIPELINE_MYSQL_PASSWORD: ${MYSQL_PASSWORD}
      PIPELINE_MYSQL_DB: ${MYSQL_DB}
      PIPELINE_MYSQL_HOST: ${MYSQL_HOST}
      PIPELINE_MYSQL_PORT: ${MYSQL_PORT}
    ports:
      - "${PIPELINE_PORT}:6789"
    depends_on:
      db:
        condition: service_healthy
    healthcheck:
      test: ["CMD-SHELL", "curl -s -o /dev/null -w '%{http_code}' http://localhost:${PIPELINE_PORT} | grep -q 200"]
      interval: 1m30s
      timeout: 30s
      retries: 5
      start_period: 10s
    volumes:
      - /home/vnbl/data_retriever:/app

volumes:
  postgres_data:
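The pipeline healthcheck above probes http://localhost:${PIPELINE_PORT} from inside the container and passes once it sees HTTP 200. A rough standard-library Python equivalent of the same probe, for illustration only (the port value is assumed from the .env defaults; inside the container the app listens on 6789):

import urllib.request

def pipeline_is_healthy(port: int = 6789) -> bool:
    """Return True if the Mage web server answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"http://localhost:{port}", timeout=5) as resp:
            return resp.status == 200
    except Exception:
        return False

print(pipeline_is_healthy())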
@@ -0,0 +1,20 @@
INSERT INTO health_checks (run_date, station_id, last_reading_id, is_on)
SELECT
    DATE_TRUNC('hour', '{{ execution_date }}'::timestamp) AS run_date, -- Truncate to the hour
    s.id AS station_id,
    sr.id AS last_reading_id,
    CASE
        WHEN sr.date_utc >= DATE_TRUNC('hour', '{{ execution_date }}'::timestamp) - INTERVAL '6 hours' THEN TRUE
        ELSE FALSE
    END AS is_on
FROM
    stations s
    LEFT JOIN LATERAL (
        SELECT sr.id, sr.date_utc
        FROM station_readings_gold sr
        WHERE sr.station = s.id
        ORDER BY sr.date_utc DESC
        LIMIT 1
    ) sr ON TRUE
ORDER BY
    s.id;
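In words: for each station, take its most recent gold reading and mark it `is_on` when that reading is at most six hours older than the hour-truncated execution date. A toy pandas sketch of the same rule, with made-up rows rather than the real tables:

import pandas as pd

stations = pd.DataFrame({"id": [1, 2, 3]})        # stand-in for the stations table
readings = pd.DataFrame({                         # stand-in for station_readings_gold
    "station": [1, 1, 2],
    "date_utc": pd.to_datetime(["2024-09-01 08:00", "2024-09-01 11:00", "2024-08-31 20:00"]),
})

run_date = pd.Timestamp("2024-09-01 12:30").floor("h")                # DATE_TRUNC('hour', execution_date)
last = readings.sort_values("date_utc").groupby("station").tail(1)    # LEFT JOIN LATERAL ... LIMIT 1

health = stations.merge(last, left_on="id", right_on="station", how="left")
health["is_on"] = health["date_utc"] >= run_date - pd.Timedelta(hours=6)
print(health[["id", "date_utc", "is_on"]])   # station 3 has no reading, so is_on is False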
@@ -0,0 +1,2 @@
INSERT INTO inference_runs (run_date)
VALUES ('{{ execution_date }}'::timestamp);
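`{{ execution_date }}` is interpolated by Mage when the SQL block runs. Purely as an illustration (not code from this repository, and assuming psycopg2 plus the Postgres settings shown earlier), the same insert issued from Python would look like:

import os
import psycopg2  # assumed driver; not pinned by this diff

def record_inference_run(execution_date):
    """Mirror of the SQL block above: insert one row into inference_runs."""
    conn = psycopg2.connect(
        host=os.environ["POSTGRES_HOST"],
        port=os.environ.get("POSTGRES_PORT", "5432"),
        dbname=os.environ["POSTGRES_DB"],
        user=os.environ["POSTGRES_USER"],
        password=os.environ["POSTGRES_PASSWORD"],
    )
    with conn, conn.cursor() as cur:
        cur.execute("INSERT INTO inference_runs (run_date) VALUES (%s)", (execution_date,))
    conn.close()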
@@ -0,0 +1,15 @@
-- Docs: https://docs.mage.ai/guides/sql-blocks
WITH execution_date AS (
    SELECT run_date
    FROM inference_runs
    ORDER BY id DESC
    LIMIT 1
)

SELECT
    station_id AS id
FROM health_checks
WHERE
    run_date = (SELECT date_trunc('hour', run_date) - INTERVAL '1 HOUR' FROM execution_date)
    AND is_on = TRUE
    AND station_id NOT IN (SELECT id FROM stations WHERE is_pattern_station = TRUE);
etl-pipeline/custom/get_last_airnow_reading_silver_and_bbox.sql (2 changes: 1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 SELECT s.id AS station_id,
        r.bbox
 FROM regions r
-LEFT JOIN stations s ON s.region = r.region_code
+LEFT JOIN stations s ON s.region = r.id
 WHERE s.is_pattern_station = TRUE
 GROUP BY s.id, r.bbox;
@@ -0,0 +1,135 @@
import pandas as pd
import json
from glob import glob
from datetime import datetime
import os
from darts import TimeSeries
from darts.models import LightGBMModel

if 'custom' not in globals():
    from mage_ai.data_preparation.decorators import custom
if 'test' not in globals():
    from mage_ai.data_preparation.decorators import test


def prepare_data(data):
    data.drop(columns=['station', 'inference_run'], inplace=True)
    run_date = pd.to_datetime(data['run_date'].iloc[0], utc=True).tz_convert(None)
    data['date_utc'] = pd.to_datetime(data['date_utc'], utc=True).dt.tz_convert(None)
    data = data.drop_duplicates(subset='date_utc').sort_values(by='date_utc')
    min_date_utc = data['date_utc'].min()

    if pd.api.types.is_datetime64tz_dtype(min_date_utc):
        min_date_utc = min_date_utc.tz_convert(None)

    full_range = pd.date_range(start=min_date_utc, end=run_date, freq='H')

    data = data.set_index('date_utc').reindex(full_range).rename_axis('date_utc').reset_index()

    data.fillna(method='ffill', inplace=True)
    data.fillna(method='bfill', inplace=True)
    data = data[data['date_utc'] <= run_date]
    data.reset_index(drop=True, inplace=True)

    if 'run_date' in data.columns:
        data.drop(columns=['run_date'], inplace=True)
    return data


def get_latest_model_path(model_dir, model_name, klogger):
    try:
        all_files = os.listdir(model_dir)
        model_files = []
        for file_name in all_files:
            file_name_no_ext = os.path.splitext(file_name)[0]
            try:
                # Format: <YYYY-MM-DD_vX.X.X_model-6h/12h.pkl>
                date_part, version_part, model_part = file_name_no_ext.split('_')
                datetime.strptime(date_part, "%Y-%m-%d")

                if model_part == model_name and version_part.startswith('v'):
                    version_numbers = version_part[1:].split('.')
                    if all(num.isdigit() for num in version_numbers):
                        model_files.append(file_name)
            except (ValueError, IndexError):
                klogger.warning(f'Invalid filename: {file_name}')
                continue

        if not model_files:
            klogger.exception(f"No valid models found for {model_name} in directory {model_dir}")
            return None
        # Latest model first in array:
        model_files.sort(reverse=True)
        return os.path.join(model_dir, model_files[0])
    except Exception as e:
        klogger.exception(f"An error occurred while getting latest model path: {e}")
        return None


def load_models(klogger, model_dir='etl-pipeline/models/'):
    try:
        model_12h_path = get_latest_model_path(model_dir, 'model-12h', klogger)
        model_6h_path = get_latest_model_path(model_dir, 'model-6h', klogger)

        model_12h = LightGBMModel.load(model_12h_path)
        model_6h = LightGBMModel.load(model_6h_path)

        return model_6h, model_12h
    except Exception as e:
        klogger.exception(f'An error occurred while loading models: {e}')


def predict_aqi(data, model, output_length, klogger):
    try:
        target = 'aqi_pm2_5'
        covariates = list(data.columns.drop(['date_utc']))
        ts = TimeSeries.from_dataframe(data, time_col='date_utc', value_cols=[target] + covariates, freq='h')
        target_data = ts[target]
        covariates_data = ts[covariates]

        y_pred = model.predict(output_length, series=target_data, past_covariates=covariates_data)

        y_pred_series = y_pred.pd_series().round(0)
        result = [
            {
                "timestamp": timestamp.isoformat(),
                "value": int(value)
            }
            for timestamp, value in y_pred_series.items()
        ]
        return result
    except Exception as e:
        klogger.exception(f'An error occurred while predicting aqi: {e}')


@custom
def transform_custom(data, *args, **kwargs):
    klogger = kwargs.get('logger')
    try:
        station = data['station'].iloc[0]
        inference_run = data['inference_run'].iloc[0]

        pred_data = prepare_data(data)

        aqi_df = pred_data[['date_utc', 'aqi_pm2_5']]
        aqi_json_list = aqi_df.apply(lambda row: {"timestamp": row['date_utc'].isoformat(), "value": int(row['aqi_pm2_5'])}, axis=1).tolist()
        aqi_json = json.dumps(aqi_json_list, indent=4)

        model_6h, model_12h = load_models(klogger=klogger)
        forecast_12h = predict_aqi(pred_data, model_12h,
                                   output_length=12, klogger=klogger)
        forecast_6h = predict_aqi(pred_data, model_6h,
                                  output_length=6, klogger=klogger)
        result_df = pd.DataFrame({
            'inference_run': [inference_run],
            'station': [station],
            'aqi_input': [aqi_json],
            'forecasts_6h': [forecast_6h],
            'forecasts_12h': [forecast_12h]
        })
        return result_df
    except Exception as e:
        klogger.exception(e)


@test
def test_output(output, *args) -> None:
    """
    Template code for testing the output of the block.
    """
    assert output is not None, 'The output is undefined'
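To make the expected input concrete, here is a small self-contained sketch (synthetic rows and hypothetical model file names, not taken from the repository) of what prepare_data returns and which file get_latest_model_path's reverse lexicographic sort would pick; running transform_custom end to end additionally requires the pickled LightGBM models under etl-pipeline/models/:

import pandas as pd

# Synthetic single-station input; column names follow the block above.
raw = pd.DataFrame({
    "station": [4, 4],
    "inference_run": [99, 99],
    "run_date": ["2024-09-01 12:00:00+00:00"] * 2,
    "date_utc": ["2024-09-01 08:00:00+00:00", "2024-09-01 10:00:00+00:00"],
    "aqi_pm2_5": [42.0, 55.0],
})
prepared = prepare_data(raw.copy())
print(prepared)  # hourly rows from 08:00 to 12:00, gaps forward/back-filled, run_date dropped

# Hypothetical model directory listing: sorting in reverse picks the newest date first.
names = ["2024-07-01_v1.0.0_model-6h.pkl", "2024-08-15_v1.1.0_model-6h.pkl"]
print(sorted(names, reverse=True)[0])  # -> 2024-08-15_v1.1.0_model-6h.pkl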