Skip to content

Commit

Permalink
Merge pull request #46 from statgarten/kgh
Browse files Browse the repository at this point in the history
시계열 excel 확장자 추가 코드 및 server 수정
  • Loading branch information
MinDongRyul authored Oct 31, 2023
2 parents 2889203 + 777fd0c commit 19a27ab
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 33 deletions.
35 changes: 19 additions & 16 deletions fastapi_backend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
kobert-tokenizer @ git+https://github.com/SKTBrain/KoBERT.git#egg=kobert_tokenizer&subdirectory=kobert_hf
sentencepiece
streamlit
scikit-learn
pandas
numpy
pillow
soundfile
torch
torchvision
transformers
fastapi
python-multipart
uvicorn
torchaudio
librosa
pydub
sentencepiece==0.1.99
gdown==4.7.1
fastapi==0.104.0
librosa==0.10.1
openai==0.28.1
openpyxl==3.1.2
numpy==1.26.1
streamlit==1.28.0
scikit-learn==1.3.2
soundfile==0.12.1
torch==2.1.0
torchvision==0.16.0
torchaudio==2.1.0
transformers==4.34.1
pandas==2.1.2
pydub==0.25.1
python-multipart==0.0.6
pillow==10.1.0
uvicorn==0.23.2
6 changes: 5 additions & 1 deletion fastapi_backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,11 @@ async def time_train_endpoint(data_arranges:list[str],
label_data = 'Label_Data'

content = await file.read()
train_df = pd.read_csv(BytesIO(content), encoding='utf-8')
if file_ext == 'csv':
train_df = pd.read_csv(BytesIO(content), encoding='utf-8')
elif file_ext == 'xlsx':
train_df = pd.read_excel(BytesIO(content))


train_df[date] = pd.to_datetime(train_df[date])

Expand Down
14 changes: 8 additions & 6 deletions streamlit_frontend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
streamlit
pandas
pillow
python-multipart
streamlit-modal
python-dotenv
openpyxl==3.1.2
openai==0.28.1
streamlit==1.24.1
streamlit-modal==0.1.0
pandas==2.1.2
pillow==9.5.0
python-dotenv==1.0.0
python-multipart==0.0.6
26 changes: 16 additions & 10 deletions streamlit_frontend/time_series_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from dotenv import load_dotenv
import os
import openpyxl

from PIL import Image

Expand Down Expand Up @@ -246,7 +247,8 @@ def main():
image = Image.open('timeseries_sample_data/timeseries_forecasting_sample_image.png')
st.image(image, caption=None)

uploaded_file = st.file_uploader(upload_train_csv, accept_multiple_files=False, type=['csv'])
#uploaded_file = st.file_uploader(upload_train_csv, accept_multiple_files=False, type=['csv'])
uploaded_file = st.file_uploader(upload_train_csv, accept_multiple_files=False, type=['csv', 'xlsx'])

if uploaded_file or sample_start:

Expand All @@ -260,9 +262,12 @@ def main():
csv_files = {'file': csv_file_obj}

if uploaded_file.name.split('.')[-1] == 'csv':
train_df = pd.read_csv(uploaded_file, encoding='utf-8')
elif uploaded_file.name.split('.')[-1] == 'excel':
train_df = pd.read_excel(uploaded_file)
train_df = pd.read_csv(uploaded_file, encoding='utf-8')
elif uploaded_file.name.split('.')[-1] == 'xlsx':
train_df = pd.read_excel(uploaded_file)

file_ext = uploaded_file.name.split('.')[-1]


elif sample_start:
train_df = pd.read_csv('timeseries_sample_data/timeseries_forecasting_sample_data.csv', encoding='utf-8')
Expand Down Expand Up @@ -376,12 +381,13 @@ def main():
prev_window_size = int(window_size)
with st.spinner(training_model_spinner):
response = requests.post(f"{BACKEND_URL}/time_train", files=csv_files, data={'data_arranges':list(data_arrange),
'window_size':int(window_size),
'horizon_factor':int(horizon_factor),
'epoch':int(epoch),
'learning_rate':float(learning_rate),
'pred_col': int_col,
'date': date})
'window_size':int(window_size),
'horizon_factor':int(horizon_factor),
'epoch':int(epoch),
'learning_rate':float(learning_rate),
'pred_col': int_col,
'date': date,
'file_ext':file_ext})

if response.ok:
res = response.json()
Expand Down

0 comments on commit 19a27ab

Please sign in to comment.