Skip to content

Commit

Permalink
Adding support for LLama API
Browse files Browse the repository at this point in the history
  • Loading branch information
hemantDwivedi committed Oct 7, 2023
1 parent 1b63602 commit 42bcc7e
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package com.edgechain.lib.endpoint.impl.llm;

import com.edgechain.lib.endpoint.Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.request.ArkRequest;
import com.edgechain.lib.retrofit.Llama2Service;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
import com.edgechain.lib.rxjava.retry.RetryPolicy;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.reactivex.rxjava3.core.Observable;
import org.modelmapper.ModelMapper;
import retrofit2.Retrofit;

import java.util.List;
import java.util.Objects;

/**
 * Endpoint describing a plain-text Llama completion call.
 *
 * <p>Holds the payload fields ({@code text_inputs}, {@code return_full_text}, {@code top_k})
 * together with a chain name and a call identifier. A copy of this object, produced with
 * {@link ModelMapper}, is what gets handed to {@link Llama2Service#llamaCompletion}.
 */
public class LLamaQuickstart extends Endpoint {
  private final Retrofit retrofit = RetrofitClientInstance.getInstance();
  private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);
  private final ModelMapper modelMapper = new ModelMapper();

  /** Prompt text; set per call by {@link #chatCompletion(String, String, ArkRequest)}. */
  @JsonProperty("text_inputs")
  private String textInputs;

  @JsonProperty("return_full_text")
  private Boolean returnFullText;

  @JsonProperty("top_k")
  private Integer topK;

  private String chainName;
  private String callIdentifier;

  /** Creates an endpoint with every field left unset. */
  public LLamaQuickstart() {}

  /**
   * Creates an endpoint with {@code returnFullText = false} and {@code topK = 50}.
   *
   * @param url target URL passed to the parent {@link Endpoint}
   * @param retryPolicy retry policy passed to the parent {@link Endpoint}
   */
  public LLamaQuickstart(String url, RetryPolicy retryPolicy) {
    this(url, retryPolicy, false, 50);
  }

  /**
   * Creates an endpoint with explicit generation settings.
   *
   * @param url target URL passed to the parent {@link Endpoint}
   * @param retryPolicy retry policy passed to the parent {@link Endpoint}
   * @param returnFullText value sent as {@code return_full_text}
   * @param topK value sent as {@code top_k}
   */
  public LLamaQuickstart(String url, RetryPolicy retryPolicy, Boolean returnFullText, Integer topK) {
    super(url, retryPolicy);
    this.returnFullText = returnFullText;
    this.topK = topK;
  }

  public String getTextInputs() {
    return this.textInputs;
  }

  public void setTextInputs(String textInputs) {
    this.textInputs = textInputs;
  }

  public Boolean getReturnFullText() {
    return this.returnFullText;
  }

  public void setReturnFullText(Boolean returnFullText) {
    this.returnFullText = returnFullText;
  }

  public Integer getTopK() {
    return this.topK;
  }

  public void setTopK(Integer topK) {
    this.topK = topK;
  }

  public String getChainName() {
    return this.chainName;
  }

  public void setChainName(String chainName) {
    this.chainName = chainName;
  }

  public String getCallIdentifier() {
    return this.callIdentifier;
  }

  public void setCallIdentifier(String callIdentifier) {
    this.callIdentifier = callIdentifier;
  }

  /**
   * Requests a completion for the given prompt.
   *
   * <p>This instance is not modified: the prompt and chain name are applied to a mapped copy,
   * and that copy is what is sent.
   *
   * @param inputs prompt text stored as {@code text_inputs} on the copy
   * @param chainName chain name stored on the copy
   * @param arkRequest originating request; may be {@code null}
   * @return observable wrapping the service call's {@code Single<List<String>>}
   */
  public Observable<List<String>> chatCompletion(
      String inputs, String chainName, ArkRequest arkRequest) {
    final LLamaQuickstart copy = this.modelMapper.map(this, LLamaQuickstart.class);
    copy.setTextInputs(inputs);
    copy.setChainName(chainName);
    return chatCompletion(copy, arkRequest);
  }

  /**
   * Stamps the call identifier on the payload and invokes the Retrofit service.
   *
   * @param payload endpoint copy used as the request body
   * @param arkRequest originating request; when {@code null} a placeholder identifier is used
   */
  private Observable<List<String>> chatCompletion(LLamaQuickstart payload, ArkRequest arkRequest) {
    final String identifier =
        Objects.isNull(arkRequest) ? "URI wasn't provided" : arkRequest.getRequestURI();
    payload.setCallIdentifier(identifier);

    return Observable.fromSingle(this.llama2Service.llamaCompletion(payload));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.edgechain.lib.llama2;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.rxjava3.core.Observable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.List;

/**
 * Spring service that posts a {@link LLamaCompletionRequest} to the URL of a
 * {@link LLamaQuickstart} endpoint and parses the response body as a list of strings.
 */
@Service
public class LLamaClient {
  @Autowired private ObjectMapper objectMapper;
  private final Logger logger = LoggerFactory.getLogger(getClass());
  private final RestTemplate restTemplate = new RestTemplate();

  /**
   * Builds a chain that performs the completion call when the observable is subscribed.
   *
   * <p>Any exception raised while logging, posting or parsing is forwarded through
   * {@code onError} instead of being thrown.
   *
   * @param request payload sent as the request body
   * @param endpoint endpoint supplying the target URL; also passed to the {@link EdgeChain}
   * @return chain emitting the parsed list once, then completing
   */
  public EdgeChain<List<String>> createChatCompletion(
      LLamaCompletionRequest request, LLamaQuickstart endpoint) {
    Observable<List<String>> completionCall =
        Observable.create(
            emitter -> {
              try {
                logger.info("Logging ChatCompletion....");
                logger.info("==============REQUEST DATA================");
                logger.info(request.toString());

                HttpEntity<LLamaCompletionRequest> httpEntity = asJsonEntity(request);
                String rawBody =
                    restTemplate.postForObject(endpoint.getUrl(), httpEntity, String.class);

                List<String> completions =
                    objectMapper.readValue(rawBody, new TypeReference<List<String>>() {});

                emitter.onNext(completions);
                emitter.onComplete();
              } catch (final Exception e) {
                emitter.onError(e);
              }
            });

    return new EdgeChain<>(completionCall, endpoint);
  }

  /** Wraps the payload in an entity whose Content-Type header is application/json. */
  private static HttpEntity<LLamaCompletionRequest> asJsonEntity(LLamaCompletionRequest payload) {
    HttpHeaders jsonHeaders = new HttpHeaders();
    jsonHeaders.setContentType(MediaType.APPLICATION_JSON);
    return new HttpEntity<>(payload, jsonHeaders);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.edgechain.lib.llama2.request;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.StringJoiner;

/**
 * Request body for the Llama completion call, serialized with the JSON property names
 * {@code text_inputs}, {@code return_full_text} and {@code top_k}.
 */
public class LLamaCompletionRequest {
  @JsonProperty("text_inputs")
  private String textInputs;

  @JsonProperty("return_full_text")
  private Boolean returnFullText;

  @JsonProperty("top_k")
  private Integer topK;

  /** Creates a request with every field unset. */
  public LLamaCompletionRequest() {}

  /**
   * Creates a fully populated request.
   *
   * @param textInputs value for {@code text_inputs}
   * @param returnFullText value for {@code return_full_text}
   * @param topK value for {@code top_k}
   */
  public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) {
    this.textInputs = textInputs;
    this.returnFullText = returnFullText;
    this.topK = topK;
  }

  /** Returns a new, empty builder. */
  public static LlamaSupportChatCompletionRequestBuilder builder() {
    return new LlamaSupportChatCompletionRequestBuilder();
  }

  /**
   * Renders the three fields as {@code LLamaCompletionRequest{...}}, separated by {@code ", "}.
   *
   * <p>NOTE(review): each label puts the closing quote after the colon
   * ({@code "text_inputs:"value}); kept as-is here — confirm whether that is intended.
   */
  @Override
  public String toString() {
    final String body =
        String.join(
            ", ",
            "\"text_inputs:\"" + textInputs,
            "\"return_full_text:\"" + returnFullText,
            "\"top_k:\"" + topK);
    return LLamaCompletionRequest.class.getSimpleName() + "{" + body + "}";
  }

  public String getTextInputs() {
    return this.textInputs;
  }

  public void setTextInputs(String textInputs) {
    this.textInputs = textInputs;
  }

  public Boolean getReturnFullText() {
    return this.returnFullText;
  }

  public void setReturnFullText(Boolean returnFullText) {
    this.returnFullText = returnFullText;
  }

  public Integer getTopK() {
    return this.topK;
  }

  public void setTopK(Integer topK) {
    this.topK = topK;
  }

  /** Fluent builder for {@link LLamaCompletionRequest}; obtain one via {@link #builder()}. */
  public static class LlamaSupportChatCompletionRequestBuilder {
    private String textInputs;
    private Boolean returnFullText;
    private Integer topK;

    private LlamaSupportChatCompletionRequestBuilder() {}

    public LlamaSupportChatCompletionRequestBuilder textInputs(String textInputs) {
      this.textInputs = textInputs;
      return this;
    }

    public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFullText) {
      this.returnFullText = returnFullText;
      return this;
    }

    public LlamaSupportChatCompletionRequestBuilder topK(Integer topK) {
      this.topK = topK;
      return this;
    }

    /** Produces a request carrying whatever values have been set so far (unset ones are null). */
    public LLamaCompletionRequest build() {
      final LLamaCompletionRequest built = new LLamaCompletionRequest();
      built.setTextInputs(this.textInputs);
      built.setReturnFullText(this.returnFullText);
      built.setTopK(this.topK);
      return built;
    }
  }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.edgechain.lib.retrofit;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import io.reactivex.rxjava3.core.Single;
Expand All @@ -9,6 +10,8 @@
import java.util.List;

/** Retrofit service declaring the Llama HTTP calls used by the endpoint classes. */
public interface Llama2Service {
  /**
   * POSTs the given endpoint as the request body to {@code llama/chat-completion}.
   *
   * <p>A Retrofit method may carry only one HTTP method annotation, so the superseded
   * {@code llama2/chat-completion} mapping is not kept alongside this one.
   */
  @POST(value = "llama/chat-completion")
  Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint);

  /** POSTs the given endpoint as the request body to {@code llama/completion}. */
  @POST(value = "llama/completion")
  Single<List<String>> llamaCompletion(@Body LLamaQuickstart llama2Endpoint);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.List;

@RestController("Service Llama2Controller")
@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama2")
@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama")
public class Llama2Controller {
@Autowired private ChatCompletionLogService chatCompletionLogService;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.edgechain.service.controllers.llama2;

import com.edgechain.lib.configuration.WebConfiguration;
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.llama2.LLamaClient;
import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import io.reactivex.rxjava3.core.Single;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;

/**
 * REST controller exposing the Llama completion route under
 * {@code WebConfiguration.CONTEXT_PATH + "/llama"}.
 */
@RestController("Service LlamaController")
@RequestMapping(value = WebConfiguration.CONTEXT_PATH + "/llama") // handler below adds "/completion"
public class LlamaController {

  @Autowired private LLamaClient lLamaClient;

  /**
   * Handles {@code POST .../llama/completion}.
   *
   * <p>Copies the text inputs, return-full-text flag and top-k value from the posted endpoint
   * into a {@link LLamaCompletionRequest}, then delegates to {@link LLamaClient}.
   *
   * @param lLamaQuickstart endpoint deserialized from the request body
   * @return the client chain converted to a {@link Single}
   */
  @PostMapping(value = "/completion")
  public Single<List<String>> chatCompletion(@RequestBody LLamaQuickstart lLamaQuickstart) {
    final LLamaCompletionRequest completionRequest =
        LLamaCompletionRequest.builder()
            .textInputs(lLamaQuickstart.getTextInputs())
            .returnFullText(lLamaQuickstart.getReturnFullText())
            .topK(lLamaQuickstart.getTopK())
            .build();

    return lLamaClient.createChatCompletion(completionRequest, lLamaQuickstart).toSingle();
  }
}

0 comments on commit 42bcc7e

Please sign in to comment.