diff --git a/pyproject.toml b/pyproject.toml
index 9d5927f..6235527 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "texify"
-version = "0.1.7"
+version = "0.1.8"
 description = "OCR for latex images"
 authors = ["Vik Paruchuri "]
 readme = "README.md"
diff --git a/texify/inference.py b/texify/inference.py
index 34fa8f2..935e44a 100644
--- a/texify/inference.py
+++ b/texify/inference.py
@@ -2,11 +2,11 @@
 from texify.output import postprocess
 
 
-def batch_inference(images, model, processor, temperature=settings.TEMPERATURE):
+def batch_inference(images, model, processor, temperature=settings.TEMPERATURE, max_tokens=settings.MAX_TOKENS):
     images = [image.convert("RGB") for image in images]
     encodings = processor(images=images, return_tensors="pt", add_special_tokens=False)
-    pixel_values = encodings["pixel_values"].to(settings.MODEL_DTYPE)
-    pixel_values = pixel_values.to(settings.TORCH_DEVICE_MODEL)
+    pixel_values = encodings["pixel_values"].to(model.dtype)
+    pixel_values = pixel_values.to(model.device)
 
     additional_kwargs = {}
     if temperature > 0:
@@ -16,7 +16,7 @@ def batch_inference(images, model, processor, temperature=settings.TEMPERATURE):
 
     generated_ids = model.generate(
         pixel_values=pixel_values,
-        max_new_tokens=settings.MAX_TOKENS,
+        max_new_tokens=max_tokens,
         decoder_start_token_id=processor.tokenizer.bos_token_id,
         **additional_kwargs,
     )
diff --git a/texify/model/model.py b/texify/model/model.py
index 1087fb0..dc4db0c 100644
--- a/texify/model/model.py
+++ b/texify/model/model.py
@@ -10,16 +10,14 @@
 from texify.settings import settings
 
 
-def load_model():
-    config = get_config(settings.MODEL_CHECKPOINT)
+def load_model(checkpoint=settings.MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
+    config = get_config(checkpoint)
     AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)
 
-    dtype = settings.MODEL_DTYPE
-    device = settings.TORCH_DEVICE_MODEL
-    model = VisionEncoderDecoderModel.from_pretrained(settings.MODEL_CHECKPOINT, config=config, torch_dtype=dtype)
+    model = VisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
     model = model.to(device)
     model = model.eval()
-    print(f"Loaded model to {device} with {dtype} dtype")
+    print(f"Loaded texify model to {device} with {dtype} dtype")
     return model
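
The net effect of this patch is that the checkpoint, device, dtype, and generation token budget become overridable per call instead of being read only from settings, and pixel_values now follows the loaded model's own dtype and device. A minimal usage sketch of the new signatures follows (assumptions beyond this diff: a load_processor helper in texify.model.processor, and batch_inference returning a list of strings; neither is shown above):

    # Sketch only: load_processor and the return value of batch_inference are
    # assumed from the surrounding package, not shown in this diff.
    from PIL import Image

    from texify.inference import batch_inference
    from texify.model.model import load_model
    from texify.model.processor import load_processor  # assumed helper

    model = load_model()        # defaults still come from settings
    processor = load_processor()

    image = Image.open("equation.png")
    # max_tokens now overrides settings.MAX_TOKENS for this call only.
    results = batch_inference([image], model, processor, max_tokens=384)
    print(results[0])

Because the input tensors are cast with model.dtype and model.device rather than the settings constants, they stay consistent with whatever was passed to load_model, e.g. load_model(device="cpu", dtype=torch.float32) needs no matching settings change.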