=================================================================== FAILURES ===================================================================
________________________________ GreedySearchTextGenerationTest.test_model_compile_batched_ds_jit_compile_false ________________________________

self = <keras_nlp.utils.text_generation_test.GreedySearchTextGenerationTest testMethod=test_model_compile_batched_ds_jit_compile_false>
jit_compile = False

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_model_compile_batched_ds(self, jit_compile):
        def token_probability_fn(inputs):
            prob = tf.constant([[0.0, 0.0, 0.0, 1.0]])
            return tf.repeat(prob, 2, axis=0)
    
        max_length = 5
    
        class TestModel(tf.keras.Model):
            def call(self, inputs, training=False):
                if not training:
                    generated = greedy_search(
                        token_probability_fn,
                        inputs,
                        max_length=max_length,
                        end_token_id=2,
                        pad_token_id=0,
                    )
                    return generated
                else:
                    return inputs
    
        inputs = tf.constant([[0, 1], [1, 2]])
        ds = tf.data.Dataset.from_tensor_slices(inputs).batch(2)
        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
    
        model = TestModel()
        model.compile(jit_compile=jit_compile)
    
>       outputs = model.predict(ds)

keras_nlp/utils/text_generation_test.py:183: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py:60: in error_handler
    return fn(*args, **kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py:2033: in predict
    tmp_batch_outputs = self.predict_function(iterator)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:141: in error_handler
    return fn(*args, **kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:915: in __call__
    result = self._call(*args, **kwds)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:963: in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:785: in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2480: in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2711: in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2627: in _create_graph_function
    func_graph_module.func_graph_from_py_func(
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1141: in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:677: in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fe9bc565dc0>,), kwargs = {}

    def autograph_handler(*args, **kwargs):
      """Calls a converted version of original_func."""
      # TODO(mdan): Push this block higher in tf.function's call stack.
      try:
        return autograph.converted_call(
            original_func,
            args,
            kwargs,
            options=autograph.ConversionOptions(
                recursive=True,
                optional_features=autograph_options,
                user_requested=True,
            ))
      except Exception as e:  # pylint:disable=broad-except
        if hasattr(e, "ag_error_metadata"):
>         raise e.ag_error_metadata.to_exception(e)
E         tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
E         
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1845, in predict_function  *
E                 return step_function(self, iterator)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1834, in step_function  **
E                 outputs = model.distribute_strategy.run(run_step, args=(data,))
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1312, in run
E                 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2888, in call_for_each_replica
E                 return self._call_for_each_replica(fn, args, kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3689, in _call_for_each_replica
E                 return fn(*args, **kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1823, in run_step  **
E                 outputs = model.predict_step(data)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1791, in predict_step
E                 return self(x, training=False)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E                 return fn(*args, **kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
E                 return super().__call__(*args, **kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E                 return fn(*args, **kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
E                 outputs = call_fn(inputs, *args, **kwargs)
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 146, in error_handler
E                 raise new_e.with_traceback(e.__traceback__) from None
E             File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
E                 return fn(*args, **kwargs)
E         
E             OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_model" (type TestModel).
E             
E             in user code:
E             
E                 File "/home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 164, in call  *
E                     generated = greedy_search(
E                 File "/home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation.py", line 142, in greedy_search  *
E                     batch_size, length = tf.shape(prompt)
E                 File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 579, in __iter__
E                     self._disallow_iteration()
E                 File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 571, in _disallow_iteration
E                     self._disallow_when_autograph_enabled(
E                 File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 539, in _disallow_when_autograph_enabled
E                     raise errors.OperatorNotAllowedInGraphError(
E             
E                 OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
E             
E             
E             Call arguments received by layer "test_model" (type TestModel):
E               • inputs=tf.Tensor(shape=(None, 2), dtype=int32)
E               • training=False

../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1127: OperatorNotAllowedInGraphError