derek-thomas HF staff commited on
Commit
97ab62b
0 Parent(s):

Duplicate from derek-thomas/probabilistic-forecast

Browse files
Files changed (8) hide show
  1. .gitattributes +34 -0
  2. .gitignore +2 -0
  3. AirPassengers.csv +1 -0
  4. README.md +14 -0
  5. app.py +74 -0
  6. make_plot.py +114 -0
  7. packages.txt +0 -0
  8. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea
2
+ lightning_logs
AirPassengers.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ Month,#Passengers
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Probablistic Forecasting
3
+ emoji: 🐨
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: derek-thomas/probabilistic-forecast
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from gluonts.dataset.pandas import PandasDataset
4
+ from gluonts.dataset.split import split
5
+ from gluonts.torch.model.deepar import DeepAREstimator
6
+
7
+ from make_plot import plot_forecast, plot_train_test
8
+
9
+
10
+ def offset_calculation(prediction_length, rolling_windows, length):
11
+ row_offset = -1 * prediction_length * rolling_windows
12
+ if abs(row_offset) > 0.95 * length:
13
+ raise gr.Error("Reduce prediction_length * rolling_windows")
14
+ return row_offset
15
+
16
+
17
+ def preprocess(input_data, prediction_length, rolling_windows, progress=gr.Progress(track_tqdm=True)):
18
+ df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
19
+ row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
20
+ return plot_train_test(df.iloc[:row_offset], df.iloc[row_offset:])
21
+
22
+
23
+ def train_and_forecast(input_data, prediction_length, rolling_windows, epochs, progress=gr.Progress(track_tqdm=True)):
24
+ if not input_data:
25
+ raise gr.Error("Upload a file with the Upload button")
26
+ try:
27
+ df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
28
+ except AttributeError:
29
+ raise gr.Error("Upload a file with the Upload button")
30
+
31
+ row_offset = offset_calculation(prediction_length, rolling_windows, len(df))
32
+
33
+ gluon_df = PandasDataset(df, target=df.columns[0])
34
+
35
+ training_data, test_gen = split(gluon_df, offset=row_offset)
36
+
37
+ model = DeepAREstimator(
38
+ prediction_length=prediction_length,
39
+ freq=gluon_df.freq,
40
+ trainer_kwargs=dict(max_epochs=epochs),
41
+ ).train(
42
+ training_data=training_data,
43
+ )
44
+
45
+ test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=rolling_windows)
46
+ forecasts = list(model.predict(test_data.input))
47
+ return plot_forecast(df, forecasts)
48
+
49
+
50
+ with gr.Blocks() as demo:
51
+ gr.Markdown("""
52
+ # How to use
53
+ Upload a univariate csv with the first column showing your dates and the second column having your data
54
+
55
+ # How it works
56
+ 1. Click **Upload** to upload your data
57
+ 2. Click **Run**
58
+ - This app will visualize your data and then train an estimator and show its predictions
59
+ """)
60
+ with gr.Accordion(label='Hyperparameters'):
61
+ with gr.Row():
62
+ prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
63
+ windows = gr.Number(value=3, label='Number of Windows', precision=0)
64
+ epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
65
+ with gr.Row():
66
+ upload_btn = gr.UploadButton(label="Upload")
67
+ train_btn = gr.Button(label="Train and Forecast")
68
+ plot = gr.Plot()
69
+
70
+ upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
71
+ train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, epochs, windows], outputs=plot)
72
+
73
+ if __name__ == "__main__":
74
+ demo.queue().launch()
make_plot.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+
7
+
8
+ def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
9
+ """
10
+ Plot the training and test datasets using Plotly.
11
+
12
+ Args:
13
+ df1 (pd.DataFrame): Train dataset
14
+ df2 (pd.DataFrame): Test dataset
15
+
16
+ Returns:
17
+ None
18
+ """
19
+
20
+ # Create a Plotly figure
21
+ fig = go.Figure()
22
+
23
+ # Add the first scatter plot with steelblue color
24
+ fig.add_trace(go.Scatter(
25
+ x=df1.index,
26
+ y=df1.iloc[:, 0],
27
+ mode='lines',
28
+ name='Training Data',
29
+ line=dict(color='steelblue'),
30
+ marker=dict(color='steelblue')
31
+ ))
32
+
33
+ # Add the second scatter plot with yellow color
34
+ fig.add_trace(go.Scatter(
35
+ x=df2.index,
36
+ y=df2.iloc[:, 0],
37
+ mode='lines',
38
+ name='Test Data',
39
+ line=dict(color='gold'),
40
+ marker=dict(color='gold')
41
+ ))
42
+
43
+ # Customize the layout
44
+ fig.update_layout(
45
+ title='Univariate Time Series',
46
+ xaxis=dict(title='Date'),
47
+ yaxis=dict(title='Value'),
48
+ showlegend=True,
49
+ template='plotly_white'
50
+ )
51
+ return fig
52
+
53
+
54
+ def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
55
+ """
56
+ Plot the true values and forecasts using Plotly.
57
+
58
+ Args:
59
+ df (pd.DataFrame): DataFrame with the true values. Assumed to have an index and columns.
60
+ forecasts (List[pd.DataFrame]): List of DataFrames containing the forecasts.
61
+
62
+ Returns:
63
+ go.Figure: Plotly figure object.
64
+ """
65
+
66
+ # Create a Plotly figure
67
+ fig = go.Figure()
68
+
69
+ # Add the true values trace
70
+ fig.add_trace(go.Scatter(
71
+ x=pd.to_datetime(df.index),
72
+ y=df.iloc[:, 0],
73
+ mode='lines',
74
+ name='True values',
75
+ line=dict(color='black')
76
+ ))
77
+
78
+ # Add the forecast traces
79
+ colors = ["green", "blue", "purple"]
80
+ for i, forecast in enumerate(forecasts):
81
+ color = colors[i]
82
+ for sample in forecast.samples:
83
+ fig.add_trace(go.Scatter(
84
+ x=forecast.index.to_timestamp(),
85
+ y=sample,
86
+ mode='lines',
87
+ opacity=0.15, # Adjust opacity to control visibility of individual samples
88
+ name=f'Forecast {i + 1}',
89
+ showlegend=False, # Hide the individual forecast series from the legend
90
+ hoverinfo='none', # Disable hover information for the forecast series
91
+ line=dict(color=color)
92
+ ))
93
+ # Add the average
94
+ mean_forecast = np.mean(forecast.samples, axis=0)
95
+ fig.add_trace(go.Scatter(
96
+ x=forecast.index.to_timestamp(),
97
+ y=mean_forecast,
98
+ mode='lines',
99
+ name=f'Mean Forecast',
100
+ line=dict(color='red', dash='dash')
101
+ ))
102
+
103
+ # Customize the layout
104
+ fig.update_layout(
105
+ title='Passenger Forecast',
106
+ xaxis=dict(title='Index'),
107
+ yaxis=dict(title='Passenger Count'),
108
+ showlegend=True,
109
+ legend=dict(x=0, y=1, font=dict(size=16)),
110
+ hovermode='x' # Enable x-axis hover for better interactivity
111
+ )
112
+
113
+ # Return the figure
114
+ return fig
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gluonts[torch,pro]
2
+ pandas
3
+ plotly