=================================================================== FAILURES ===================================================================
________________________________ GreedySearchTextGenerationTest.test_model_compile_batched_ds_jit_compile_false ________________________________
self = <keras_nlp.utils.text_generation_test.GreedySearchTextGenerationTest testMethod=test_model_compile_batched_ds_jit_compile_false>
jit_compile = False
@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
def test_model_compile_batched_ds(self, jit_compile):
def token_probability_fn(inputs):
prob = tf.constant([[0.0, 0.0, 0.0, 1.0]])
return tf.repeat(prob, 2, axis=0)
max_length = 5
class TestModel(tf.keras.Model):
def call(self, inputs, training=False):
if not training:
generated = greedy_search(
token_probability_fn,
inputs,
max_length=max_length,
end_token_id=2,
pad_token_id=0,
)
return generated
else:
return inputs
inputs = tf.constant([[0, 1], [1, 2]])
ds = tf.data.Dataset.from_tensor_slices(inputs).batch(2)
expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
model = TestModel()
model.compile(jit_compile=jit_compile)
> outputs = model.predict(ds)
keras_nlp/utils/text_generation_test.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py:60: in error_handler
return fn(*args, **kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py:2033: in predict
tmp_batch_outputs = self.predict_function(iterator)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:141: in error_handler
return fn(*args, **kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:915: in __call__
result = self._call(*args, **kwds)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:963: in _call
self._initialize(args, kwds, add_initializers_to=initializers)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:785: in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2480: in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2711: in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py:2627: in _create_graph_function
func_graph_module.func_graph_from_py_func(
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1141: in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:677: in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fe9bc565dc0>,), kwargs = {}
def autograph_handler(*args, **kwargs):
"""Calls a converted version of original_func."""
# TODO(mdan): Push this block higher in tf.function's call stack.
try:
return autograph.converted_call(
original_func,
args,
kwargs,
options=autograph.ConversionOptions(
recursive=True,
optional_features=autograph_options,
user_requested=True,
))
except Exception as e: # pylint:disable=broad-except
if hasattr(e, "ag_error_metadata"):
> raise e.ag_error_metadata.to_exception(e)
E tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
E
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1845, in predict_function *
E return step_function(self, iterator)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1834, in step_function **
E outputs = model.distribute_strategy.run(run_step, args=(data,))
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1312, in run
E return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2888, in call_for_each_replica
E return self._call_for_each_replica(fn, args, kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3689, in _call_for_each_replica
E return fn(*args, **kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1823, in run_step **
E outputs = model.predict_step(data)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1791, in predict_step
E return self(x, training=False)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E return fn(*args, **kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
E return super().__call__(*args, **kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E return fn(*args, **kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
E outputs = call_fn(inputs, *args, **kwargs)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 146, in error_handler
E raise new_e.with_traceback(e.__traceback__) from None
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
E return fn(*args, **kwargs)
E
E OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_model" (type TestModel).
E
E in user code:
E
E File "/home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 164, in call *
E generated = greedy_search(
E File "/home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation.py", line 142, in greedy_search *
E batch_size, length = tf.shape(prompt)
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 579, in __iter__
E self._disallow_iteration()
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 571, in _disallow_iteration
E self._disallow_when_autograph_enabled(
E File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 539, in _disallow_when_autograph_enabled
E raise errors.OperatorNotAllowedInGraphError(
E
E OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
E
E
E Call arguments received by layer "test_model" (type TestModel):
E • inputs=tf.Tensor(shape=(None, 2), dtype=int32)
E • training=False
../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1127: OperatorNotAllowedInGraphError