-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1b63602
commit 42bcc7e
Showing
6 changed files
with
293 additions
and
2 deletions.
There are no files selected for viewing
104 changes: 104 additions & 0 deletions
104
...ring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package com.edgechain.lib.endpoint.impl.llm; | ||
|
||
import com.edgechain.lib.endpoint.Endpoint; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.request.ArkRequest; | ||
import com.edgechain.lib.retrofit.Llama2Service; | ||
import com.edgechain.lib.retrofit.client.RetrofitClientInstance; | ||
import com.edgechain.lib.rxjava.retry.RetryPolicy; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.modelmapper.ModelMapper; | ||
import retrofit2.Retrofit; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class LLamaQuickstart extends Endpoint { | ||
private final Retrofit retrofit = RetrofitClientInstance.getInstance(); | ||
private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); | ||
private final ModelMapper modelMapper = new ModelMapper(); | ||
|
||
@JsonProperty("text_inputs") | ||
private String textInputs; | ||
@JsonProperty("return_full_text") | ||
private Boolean returnFullText; | ||
@JsonProperty("top_k") | ||
private Integer topK; | ||
|
||
private String chainName; | ||
private String callIdentifier; | ||
|
||
public LLamaQuickstart() { | ||
} | ||
|
||
public LLamaQuickstart(String url, RetryPolicy retryPolicy) { | ||
super(url, retryPolicy); | ||
this.returnFullText = false; | ||
this.topK = 50; | ||
} | ||
|
||
public LLamaQuickstart(String url, RetryPolicy retryPolicy, Boolean returnFullText, Integer topK) { | ||
super(url, retryPolicy); | ||
this.returnFullText = returnFullText; | ||
this.topK = topK; | ||
} | ||
|
||
public String getTextInputs() { | ||
return textInputs; | ||
} | ||
|
||
public void setTextInputs(String textInputs) { | ||
this.textInputs = textInputs; | ||
} | ||
|
||
public Boolean getReturnFullText() { | ||
return returnFullText; | ||
} | ||
|
||
public void setReturnFullText(Boolean returnFullText) { | ||
this.returnFullText = returnFullText; | ||
} | ||
|
||
public Integer getTopK() { | ||
return topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
public String getChainName() { | ||
return chainName; | ||
} | ||
|
||
public void setChainName(String chainName) { | ||
this.chainName = chainName; | ||
} | ||
|
||
public String getCallIdentifier() { | ||
return callIdentifier; | ||
} | ||
|
||
public void setCallIdentifier(String callIdentifier) { | ||
this.callIdentifier = callIdentifier; | ||
} | ||
|
||
public Observable<List<String>> chatCompletion( | ||
String inputs, String chainName, ArkRequest arkRequest) { | ||
|
||
LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class); | ||
mapper.setTextInputs(inputs); | ||
mapper.setChainName(chainName); | ||
return chatCompletion(mapper, arkRequest); | ||
} | ||
|
||
private Observable<List<String>> chatCompletion( | ||
LLamaQuickstart mapper, ArkRequest arkRequest) { | ||
|
||
if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); | ||
else mapper.setCallIdentifier("URI wasn't provided"); | ||
|
||
return Observable.fromSingle(this.llama2Service.llamaCompletion(mapper)); | ||
} | ||
} |
61 changes: 61 additions & 0 deletions
61
FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package com.edgechain.lib.llama2; | ||
|
||
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; | ||
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; | ||
import com.edgechain.lib.llama2.request.LLamaCompletionRequest; | ||
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.beans.factory.annotation.Autowired; | ||
import org.springframework.http.HttpEntity; | ||
import org.springframework.http.HttpHeaders; | ||
import org.springframework.http.MediaType; | ||
import org.springframework.stereotype.Service; | ||
import org.springframework.web.client.RestTemplate; | ||
|
||
import java.util.List; | ||
|
||
@Service | ||
public class LLamaClient { | ||
@Autowired private ObjectMapper objectMapper; | ||
private final Logger logger = LoggerFactory.getLogger(getClass()); | ||
private final RestTemplate restTemplate = new RestTemplate(); | ||
|
||
public EdgeChain<List<String>> createChatCompletion( | ||
LLamaCompletionRequest request, LLamaQuickstart endpoint) { | ||
return new EdgeChain<>( | ||
Observable.create( | ||
emitter -> { | ||
try { | ||
|
||
logger.info("Logging ChatCompletion...."); | ||
|
||
logger.info("==============REQUEST DATA================"); | ||
logger.info(request.toString()); | ||
|
||
// Create headers | ||
HttpHeaders headers = new HttpHeaders(); | ||
headers.setContentType(MediaType.APPLICATION_JSON); | ||
HttpEntity<LLamaCompletionRequest> entity = new HttpEntity<>(request, headers); | ||
// | ||
String response = | ||
restTemplate.postForObject(endpoint.getUrl(), entity, String.class); | ||
|
||
List<String> chatCompletionResponse = | ||
objectMapper.readValue( | ||
response, new TypeReference<>() {}); | ||
emitter.onNext(chatCompletionResponse); | ||
emitter.onComplete(); | ||
|
||
} catch (final Exception e) { | ||
emitter.onError(e); | ||
} | ||
}), | ||
endpoint); | ||
} | ||
} |
85 changes: 85 additions & 0 deletions
85
.../edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package com.edgechain.lib.llama2.request; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
import java.util.StringJoiner; | ||
|
||
public class LLamaCompletionRequest { | ||
@JsonProperty("text_inputs") | ||
private String textInputs; | ||
@JsonProperty("return_full_text") | ||
private Boolean returnFullText; | ||
@JsonProperty("top_k") | ||
private Integer topK; | ||
public LLamaCompletionRequest() {} | ||
|
||
public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) { | ||
this.textInputs = textInputs; | ||
this.returnFullText = returnFullText; | ||
this.topK = topK; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return new StringJoiner(", ", LLamaCompletionRequest.class.getSimpleName() + "{", "}") | ||
.add("\"text_inputs:\"" + textInputs) | ||
.add("\"return_full_text:\"" + returnFullText) | ||
.add("\"top_k:\"" + topK) | ||
.toString(); | ||
} | ||
|
||
public static LlamaSupportChatCompletionRequestBuilder builder() { | ||
return new LlamaSupportChatCompletionRequestBuilder(); | ||
} | ||
|
||
public String getTextInputs() { | ||
return textInputs; | ||
} | ||
|
||
public void setTextInputs(String textInputs) { | ||
this.textInputs = textInputs; | ||
} | ||
|
||
public Boolean getReturnFullText() { | ||
return returnFullText; | ||
} | ||
|
||
public void setReturnFullText(Boolean returnFullText) { | ||
this.returnFullText = returnFullText; | ||
} | ||
|
||
public Integer getTopK() { | ||
return topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
public static class LlamaSupportChatCompletionRequestBuilder { | ||
private String textInputs; | ||
private Boolean returnFullText; | ||
private Integer topK; | ||
|
||
private LlamaSupportChatCompletionRequestBuilder() {} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder textInputs(String textInputs) { | ||
this.textInputs = textInputs; | ||
return this; | ||
} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFullText) { | ||
this.returnFullText = returnFullText; | ||
return this; | ||
} | ||
|
||
public LlamaSupportChatCompletionRequestBuilder topK(Integer topK){ | ||
this.topK = topK; | ||
return this; | ||
} | ||
|
||
public LLamaCompletionRequest build() { | ||
return new LLamaCompletionRequest(textInputs, returnFullText, topK); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
...edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/LlamaController.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package com.edgechain.service.controllers.llama2; | ||
|
||
import com.edgechain.lib.configuration.WebConfiguration; | ||
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; | ||
import com.edgechain.lib.llama2.LLamaClient; | ||
import com.edgechain.lib.llama2.request.LLamaCompletionRequest; | ||
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; | ||
import io.reactivex.rxjava3.core.Single; | ||
import org.springframework.beans.factory.annotation.Autowired; | ||
import org.springframework.web.bind.annotation.PostMapping; | ||
import org.springframework.web.bind.annotation.RequestBody; | ||
import org.springframework.web.bind.annotation.RequestMapping; | ||
import org.springframework.web.bind.annotation.RestController; | ||
|
||
import java.util.List; | ||
|
||
@RestController("Service LlamaController") | ||
@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama") //"llama/completion" | ||
public class LlamaController { | ||
|
||
@Autowired private LLamaClient lLamaClient; | ||
|
||
@PostMapping(value = "/completion") | ||
public Single<List<String>> chatCompletion(@RequestBody LLamaQuickstart lLamaQuickstart) { | ||
|
||
LLamaCompletionRequest LLamaCompletionRequest = | ||
com.edgechain.lib.llama2.request.LLamaCompletionRequest.builder() | ||
.textInputs(lLamaQuickstart.getTextInputs()) | ||
.returnFullText(lLamaQuickstart.getReturnFullText()) | ||
.topK(lLamaQuickstart.getTopK()) | ||
.build(); | ||
|
||
EdgeChain<List<String>> edgeChain = | ||
lLamaClient.createChatCompletion(LLamaCompletionRequest, lLamaQuickstart); | ||
|
||
return edgeChain.toSingle(); | ||
} | ||
} |