@@ -701,37 +701,41 @@ def map_pipeline(
701701 """
702702 new_ops = []
703703 tokenizer = model .tokenizer
704- has_batches = batch_by is not None
705704 for op in self .ops :
706705 # check if the pipe has a "tokenizer" kwarg and update the kwargs if needed
707706 op = copy (op )
708707 if (
709- isinstance (op , MapOp )
710- and "tokenizer" in signature (op .pipe ).parameters
711- and "tokenizer" not in op .kwargs
712- ):
713- op .kwargs ["tokenizer" ] = tokenizer
714- elif (
715- isinstance (op , MapBatchesOp )
716- and hasattr (op .pipe , "batch_process" )
717- and "tokenizer" in signature (op .pipe .batch_process ).parameters
718- and "tokenizer" not in op .kwargs
719- ) or (
720- isinstance (op , MapBatchesOp )
721- and callable (op .pipe )
722- and "tokenizer" in signature (op .pipe ).parameters
723- and "tokenizer" not in op .kwargs
708+ (
709+ isinstance (op , MapOp )
710+ and "tokenizer" in signature (op .pipe ).parameters
711+ and "tokenizer" not in op .kwargs
712+ )
713+ or (
714+ isinstance (op , MapBatchesOp )
715+ and hasattr (op .pipe , "batch_process" )
716+ and "tokenizer" in signature (op .pipe .batch_process ).parameters
717+ and "tokenizer" not in op .kwargs
718+ )
719+ or (
720+ isinstance (op , MapBatchesOp )
721+ and callable (op .pipe )
722+ and "tokenizer" in signature (op .pipe ).parameters
723+ and "tokenizer" not in op .kwargs
724+ )
724725 ):
725726 op .kwargs ["tokenizer" ] = tokenizer
726- has_batches = True
727727 if isinstance (op , (MapOp , MapBatchesOp )):
728728 op .context ["tokenizer" ] = tokenizer
729729 new_ops .append (op )
730+ has_batches = batch_by is not None or (
731+ hasattr (p , "batch_process" ) for n , p in model .pipeline
732+ )
730733 new_ops .append (MapOp (model ._ensure_doc , {}))
731734 if has_batches :
732735 batch_size , batch_by = self .validate_batching (batch_size , batch_by )
733736 batch_by = batchify_fns .get (batch_by , batch_by )
734737 new_ops .append (BatchifyOp (batch_size , batch_by ))
738+
735739 for name , pipe in model .pipeline :
736740 if name not in model ._disabled :
737741 op = (
0 commit comments