Source code for DECIMER.decimer

import logging
import os
import pickle
import sys
from typing import List
from typing import Tuple

import numpy as np
import pystow
import tensorflow as tf

import DECIMER.pre_process as pre_process
import DECIMER.utils as utils

# Silence tensorflow model loading warnings.
logging.getLogger("absl").setLevel("ERROR")

# Silence tensorflow errors - not recommended if your model is not working properly.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# ============================================================================
# Keras 2 to Keras 3 Compatibility Layer
# ============================================================================
[docs] class KerasCompatibilityUnpickler(pickle.Unpickler): """Custom unpickler to handle Keras 2.x tokenizers in Keras 3 environment. This handles the module path changes between Keras 2 and Keras 3: - keras.preprocessing.text -> tensorflow.keras.preprocessing.text """
[docs] def find_class(self, module, name): # Redirect old Keras 2.x imports to TensorFlow Keras if module.startswith("keras."): module = module.replace("keras.", "tensorflow.keras.") return super().find_class(module, name)
[docs] def load_tokenizer(tokenizer_path: str) -> object: """Load tokenizer with Keras 2/3 compatibility. Args: tokenizer_path (str): Path to the pickled tokenizer file Returns: tokenizer: Loaded tokenizer object """ try: # Try standard pickle load first (for Keras 3 tokenizers) with open(tokenizer_path, "rb") as f: return pickle.load(f) except ModuleNotFoundError as e: if "keras.preprocessing" in str(e): # Use compatibility unpickler for Keras 2.x tokenizers with open(tokenizer_path, "rb") as f: return KerasCompatibilityUnpickler(f).load() else: raise
# Set the absolute path HERE = os.path.dirname(os.path.abspath(__file__)) # Set model to run on default GPU and allow memory to grow as much as needed. # This allows us to run multiple instances of inference in the same GPU. os.environ["CUDA_VISIBLE_DEVICES"] = "0" gpus = tf.config.experimental.list_physical_devices("GPU") for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) # Set path default_path = pystow.join("DECIMER-V2") model_urls = { "DECIMER": "https://zenodo.org/record/8300489/files/models.zip", "DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip", }
[docs] def get_models(model_urls: dict): """Download and load models from the provided URLs. This function downloads models from the provided URLs to a default location, then loads tokenizers and TensorFlow saved models. Args: model_urls (dict): A dictionary containing model names as keys and their corresponding URLs as values. Returns: tuple: A tuple containing loaded tokenizer and TensorFlow saved models. - tokenizer (object): Tokenizer for DECIMER model. - DECIMER_V2 (tf.saved_model): TensorFlow saved model for DECIMER. - DECIMER_Hand_drawn (tf.saved_model): TensorFlow saved model for DECIMER HandDrawn. """ # Download models to a default location model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls) # Load tokenizers tokenizer_path = os.path.join( model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl" ) tokenizer = load_tokenizer(tokenizer_path) # Load DECIMER models DECIMER_V2 = tf.saved_model.load(model_paths["DECIMER"]) DECIMER_Hand_drawn = tf.saved_model.load(model_paths["DECIMER_HandDrawn"]) return tokenizer, DECIMER_V2, DECIMER_Hand_drawn
tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls)
[docs] def detokenize_output(predicted_array: int) -> str: """This function takes the predicted tokens from the DECIMER model and returns the decoded SMILES string. Args: predicted_array (int): Predicted tokens from DECIMER Returns: (str): SMILES representation of the molecule """ outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()] prediction = ( "".join([str(elem) for elem in outputs]) .replace("<start>", "") .replace("<end>", "") ) return prediction
[docs] def detokenize_output_add_confidence( predicted_array: tf.Tensor, confidence_array: tf.Tensor, ) -> List[Tuple[str, float]]: """This function takes the predicted array of tokens as well as the confidence values returned by the Transformer Decoder and returns a list of tuples that contain each token of the predicted SMILES string and the confidence value. Args: predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens) Returns: str: SMILES string """ prediction_with_confidence = [ ( tokenizer.index_word[predicted_array[0].numpy()[i]], confidence_array[i].numpy(), ) for i in range(len(confidence_array)) ] # remove start and end tokens prediction_with_confidence_ = prediction_with_confidence[1:-1] decoded_prediction_with_confidence = list( [(utils.decoder(tok), conf) for tok, conf in prediction_with_confidence_] ) return decoded_prediction_with_confidence
[docs] def predict_SMILES( image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False ) -> str: """Predicts SMILES representation of a molecule depicted in the given image. Args: image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image. confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction. hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn. Returns: str: SMILES representation of the molecule in the input image, optionally with confidence values. """ chemical_structure = pre_process.decode_image(image_input) model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2 predicted_tokens, confidence_values = model(tf.constant(chemical_structure)) predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens)) if confidence: predicted_SMILES_with_confidence = detokenize_output_add_confidence( predicted_tokens, confidence_values ) return predicted_SMILES, predicted_SMILES_with_confidence return predicted_SMILES
[docs] def main(): """This function take the path of the image as user input and returns the predicted SMILES as output in CLI. Agrs: str: image_path Returns: str: predicted SMILES """ if len(sys.argv) != 2: print("Usage: {} $image_path".format(sys.argv[0])) else: SMILES = predict_SMILES(sys.argv[1]) print(SMILES)
if __name__ == "__main__": main()