clama_34b.py (forked from LikithMeruvu/Generative-code)
import streamlit as st
import requests
import json
import time
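# Overview (descriptive note added for readability): this module wires a Streamlit chat
# page to NVIDIA's hosted Code Llama 34B endpoint. code_lama_34b() streams the model's
# response over server-sent events and concatenates the "content" deltas;
# display_code_lama_34B() renders the sidebar parameter controls and the chat history.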
@st.cache_data
def code_lama_34b(token, prompt, temp, top_p, seed):
    invoke_url = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/df2bee43-fb69-42b9-9ee5-f4eabbeaf3a8"
    headers = {
        "Authorization": f"Bearer {token}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    payload = {
        "messages": [
            {
                "content": f"{prompt}",
                "role": "user"
            }
        ],
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": 1024,
        "seed": seed,
        "stream": True
    }
    try:
        response = requests.post(invoke_url, headers=headers, json=payload, stream=True)
        # List to store content values
        content_list = []
        # Get the total content length
        total_length = int(response.headers.get("content-length", 0))
        # Initialize progress bar
        progress_bar = st.progress(0)
        # Initialize progress counter
        progress_counter = 0
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data:"):
                    try:
                        json_data = json.loads(decoded_line[5:])
                        # Some stream chunks may carry no "content" key in the delta,
                        # so read it defensively instead of indexing directly.
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                        if content:
                            content_list.append(content)
                    except json.JSONDecodeError as e:
                        print(f"Error decoding JSON: {e}")
                # Update progress
                if total_length > 0:
                    progress_counter += len(decoded_line)
                    progress_bar.progress(min(progress_counter / total_length, 1.0))
                # Add a small delay to allow the progress bar to update smoothly
                time.sleep(0.01)
    except requests.RequestException as e:
        print(f"Request Exception: {e}")
        return None
    # Now content_list contains all the 'content' values from the JSON data
    response_text = "".join(content_list)
    return response_text
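# Example usage (a sketch, not part of the original app): calling the helper directly
# from another Streamlit page. The secret name "NVIDIA_API_KEY" and the sample prompt
# are assumptions for illustration only.
#
#   token = st.secrets["NVIDIA_API_KEY"]
#   answer = code_lama_34b(token, "Write a Python quicksort", temp=0.7, top_p=0.5, seed=42)
#   if answer is not None:
#       st.write(answer)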
def display_code_lama_34B(token):
    st.markdown("<h1 style='text-align:center;'>Code Llama 34B</h1>", unsafe_allow_html=True)
    with st.sidebar:
        st.title("Parameters Tuning (34B)")
        st.session_state.val = st.slider("Select Temperature", key="slider1", min_value=0.1, max_value=1.0, value=0.7, step=0.1, help="Lower temperature = more precise, higher temperature = more creative")
        if st.session_state.val > 0.9:
            st.session_state.val = 1.0
        st.write('Temperature:', st.session_state.val)
        st.session_state.val1 = st.slider("Select Top_P", key="slider2", min_value=0.1, max_value=1.0, value=0.5, step=0.1, help="Nucleus sampling probability threshold")
        if st.session_state.val1 > 0.9:
            st.session_state.val1 = 1.0
        st.write('Top_P:', st.session_state.val1)
        st.session_state.val2 = st.slider("Select Seed", key="slider3", min_value=1, max_value=1000, value=42, step=1, help="Influences the variability of generated content")
        st.write('Seed:', st.session_state.val2)
    if "messages2" not in st.session_state:
        st.session_state["messages2"] = []
    # Replay the stored conversation so the chat history persists across reruns
    for msg in st.session_state.messages2:
        with st.chat_message(msg.get("role")):
            st.write(msg.get("content"))
    prompt = st.chat_input("Ask me anything related to Coding:", max_chars=8000)
    if prompt:
        st.session_state.messages2.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
        result = code_lama_34b(token, prompt, st.session_state.val, st.session_state.val1, st.session_state.val2)
        st.session_state.messages2.append({"role": "assistant", "content": result})
        with st.chat_message("assistant"):
            st.write(result)
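# Minimal entry-point sketch (an assumption; the actual app wiring lives elsewhere in
# the repo). A separate app.py could collect the NVIDIA API key and delegate to this page:
#
#   import streamlit as st
#   from clama_34b import display_code_lama_34B
#
#   token = st.sidebar.text_input("NVIDIA API Key", type="password")
#   if token:
#       display_code_lama_34B(token)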