class DonutModel(PreTrainedModel):
    r"""
    Donut: an E2E OCR-free Document Understanding Transformer.

    The encoder maps an input document image into a set of embeddings; the
    decoder predicts a desired token sequence, which can be converted to a
    structured format, given a prompt and the encoder output embeddings.
    """

    config_class = DonutConfig
    base_model_prefix = "donut"

    def forward(
        self,
        image_tensors: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_labels: torch.Tensor,
    ):
        """
        Calculate a loss given an input image and a desired token sequence.

        The model is trained in a teacher-forcing manner.

        Args:
            image_tensors: (batch_size, num_channels, height, width)
            decoder_input_ids: (batch_size, sequence_length, embedding_dim)
                NOTE(review): shape kept as originally documented; token ids
                are usually (batch_size, sequence_length) - confirm.
            decoder_labels: (batch_size, sequence_length)

        Returns:
            Whatever ``self.decoder`` returns when called with the encoder
            output as ``encoder_hidden_states`` and ``decoder_labels`` as
            ``labels``.
        """
        # `self.encoder` / `self.decoder` are not assigned anywhere in this
        # excerpt; presumably they are created in a constructor that is elided.
        encoder_outputs = self.encoder(image_tensors)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_outputs,
            labels=decoder_labels,
        )
        return decoder_outputs

    ...  # elision marker kept from the original listing

...  # elision marker kept from the original listing
class SwinEncoder(nn.Module):
    r"""
    Donut encoder based on SwinTransformer.

    Set the initial weights and configuration with a pretrained
    SwinTransformer and then modify the detailed configurations as a Donut
    Encoder.

    Args:
        input_size: Input image size (width, height)
        align_long_axis: Whether to rotate image if height is greater than width
        window_size: Window size(=patch size) of SwinTransformer
        encoder_layer: Number of layers of SwinTransformer encoder
        name_or_path: Name of a pretrained model, either registered on
            huggingface.co or saved locally. Otherwise,
            `swin_base_patch4_window12_384` will be set (using `timm`).
    """
class BARTDecoder(nn.Module):
    """
    Donut Decoder based on Multilingual BART.

    Set the initial weights and configuration with a pretrained multilingual
    BART model, and modify the detailed configurations as a Donut decoder.

    Args:
        decoder_layer: Number of layers of BARTDecoder
        max_position_embeddings: The maximum sequence length to be trained
        name_or_path: Name of a pretrained model, either registered on
            huggingface.co or saved locally. Otherwise,
            `hyunwoongko/asian-bart-ecjk` will be set (using `transformers`).
    """
# NOTE(review): excerpt from a method body (it uses `self`); the enclosing
# `def` line is not part of this listing, so indentation here is relative.

# Build a decoder-only MBart with cross-attention enabled, sized from this
# object's layer count, maximum sequence length and tokenizer vocabulary.
self.model = MBartForCausalLM(
    config=MBartConfig(
        is_decoder=True,
        is_encoder_decoder=False,
        add_cross_attention=True,
        decoder_layers=self.decoder_layer,
        max_position_embeddings=self.max_position_embeddings,
        vocab_size=len(self.tokenizer),
        scale_embedding=True,
        add_final_layer_norm=True,
    )
)
self.model.forward = self.forward  # to get cross attentions and utilize `generate` function
self.model.config.is_encoder_decoder = True  # to get cross-attention
self.add_special_tokens(["<sep/>"])  # <sep/> is used for representing a list in a JSON
self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

# weight init with asian-bart
if not name_or_path:
    bart_state_dict = MBartForCausalLM.from_pretrained("hyunwoongko/asian-bart-ecjk").state_dict()
    new_bart_state_dict = self.model.state_dict()
    for x in new_bart_state_dict:
        if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024:
            # Positional table must be resized when the target length is not
            # 1024. The "+ 2" presumably matches the position offset used by
            # MBart's learned positional embedding (see the link below).
            new_bart_state_dict[x] = torch.nn.Parameter(
                self.resize_bart_abs_pos_emb(
                    bart_state_dict[x],
                    self.max_position_embeddings
                    + 2,  # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
                )
            )
        elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
            # Keep only the first len(tokenizer) rows of the pretrained
            # vocabulary-sized matrices.
            new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :]
        else:
            new_bart_state_dict[x] = bart_state_dict[x]
    self.model.load_state_dict(new_bart_state_dict)
loss = None if labels isnotNone: loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))
ifnot return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss isnotNoneelse output
def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize the pages of a PDF.

    Pages are either written to ``outpath`` as PNG files named ``%02d.png``
    (1-based page number) or returned as in-memory buffers. Any error raised
    while opening or rendering is logged and swallowed, so the result may be
    empty or partial.

    Args:
        pdf (Union[Path, bytes]): Path to the PDF file (``str`` or ``Path``)
            or the raw PDF bytes. Any other object is used as-is and must
            provide ``__len__`` and ``render`` like an opened document.
        outpath (Optional[Path], optional): The output directory. If None,
            the rendered pages are returned instead. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Whether to return the rendered pages
            instead of writing them to disk. Defaults to False.
        pages (Optional[List[int]], optional): The page indices to rasterize.
            If None, all pages will be rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: One ``io.BytesIO`` per rendered page,
        each holding the page saved in BMP format, if `return_pil` is True
        (or `outpath` is None); otherwise None.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        # Raw bytes are an annotated input type, so they must be opened as a
        # document too; previously they fell through and failed on `.render`.
        if isinstance(pdf, (str, Path, bytes)):
            pdf = pypdfium2.PdfDocument(pdf)
        if pages is None:
            pages = range(len(pdf))
        else:
            # Materialize once: `pages` is consumed both by the renderer and
            # by the zip below, which would starve on a one-shot iterator.
            pages = list(pages)
        renderer = pdf.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=pages,
            scale=dpi / 72,  # scale factor relative to 72 DPI
        )
        for i, image in zip(pages, renderer):
            if return_pil:
                page_bytes = io.BytesIO()
                image.save(page_bytes, "bmp")
                pils.append(page_bytes)
            else:
                # Path(...) also accepts a plain string directory.
                image.save((Path(outpath) / ("%02d.png" % (i + 1))), "png")
    except Exception as e:
        # Deliberate best-effort: report the failure and return what we have.
        logging.error(e)
    if return_pil:
        return pils
    return None