Skip to content

Commit

Permalink
Merge pull request #7 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Make more settings configurable
  • Loading branch information
VikParuchuri authored Jan 2, 2024
2 parents 6c9f43a + 97c14d1 commit 4f96e6d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "texify"
version = "0.1.7"
version = "0.1.8"
description = "OCR for latex images"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
8 changes: 4 additions & 4 deletions texify/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from texify.output import postprocess


def batch_inference(images, model, processor, temperature=settings.TEMPERATURE):
def batch_inference(images, model, processor, temperature=settings.TEMPERATURE, max_tokens=settings.MAX_TOKENS):
images = [image.convert("RGB") for image in images]
encodings = processor(images=images, return_tensors="pt", add_special_tokens=False)
pixel_values = encodings["pixel_values"].to(settings.MODEL_DTYPE)
pixel_values = pixel_values.to(settings.TORCH_DEVICE_MODEL)
pixel_values = encodings["pixel_values"].to(model.dtype)
pixel_values = pixel_values.to(model.device)

additional_kwargs = {}
if temperature > 0:
Expand All @@ -16,7 +16,7 @@ def batch_inference(images, model, processor, temperature=settings.TEMPERATURE):

generated_ids = model.generate(
pixel_values=pixel_values,
max_new_tokens=settings.MAX_TOKENS,
max_new_tokens=max_tokens,
decoder_start_token_id=processor.tokenizer.bos_token_id,
**additional_kwargs,
)
Expand Down
10 changes: 4 additions & 6 deletions texify/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
from texify.settings import settings


def load_model():
config = get_config(settings.MODEL_CHECKPOINT)
def load_model(checkpoint=settings.MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
config = get_config(checkpoint)
AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

dtype = settings.MODEL_DTYPE
device = settings.TORCH_DEVICE_MODEL
model = VisionEncoderDecoderModel.from_pretrained(settings.MODEL_CHECKPOINT, config=config, torch_dtype=dtype)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
model = model.to(device)
model = model.eval()
print(f"Loaded model to {device} with {dtype} dtype")
print(f"Loaded texify model to {device} with {dtype} dtype")
return model


Expand Down

0 comments on commit 4f96e6d

Please sign in to comment.