E                                           tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
E                                           
E                                           Input and output shapes of loop body do not match: (s32[], s32[], s32[2,2]) vs. (s32[], s32[], s32[2,3])
E                                           
E                                           Stack trace for op definition: 
E                                           File "home/abheesht/python_envs/keras_nlp/bin/pytest", line 8, in <module>
E                                             sys.exit(console_main())
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 187, in console_main
E                                             code = main()
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 164, in main
E                                             ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
E                                             return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
E                                             return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
E                                             res = hook_impl.function(*args)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 315, in pytest_cmdline_main
E                                             return wrap_session(config, _main)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 268, in wrap_session
E                                             session.exitstatus = doit(config, session) or 0
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 322, in _main
E                                             config.hook.pytest_runtestloop(session=session)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
E                                             return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
E                                             return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
E                                             res = hook_impl.function(*args)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 347, in pytest_runtestloop
E                                             item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
E                                             return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
E                                             return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
E                                             res = hook_impl.function(*args)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 111, in pytest_runtest_protocol
E                                             runtestprotocol(item, nextitem=nextitem)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 130, in runtestprotocol
E                                             reports.append(call_and_report(item, "call", log))
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 219, in call_and_report
E                                             call = call_runtest_hook(item, when, **kwds)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 258, in call_runtest_hook
E                                             return CallInfo.from_call(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 338, in from_call
E                                             result: Optional[TResult] = func()
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 259, in <lambda>
E                                             lambda: ihook(item=item, **kwds), when=when, reraise=reraise
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
E                                             return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
E                                             return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
E                                             res = hook_impl.function(*args)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 166, in pytest_runtest_call
E                                             item.runtest()
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/unittest.py", line 327, in runtest
E                                             self._testcase(result=self)  # type: ignore[arg-type]
E                                           File "usr/lib/python3.8/unittest/case.py", line 736, in __call__
E                                             return self.run(*args, **kwds)
E                                           File "usr/lib/python3.8/unittest/case.py", line 676, in run
E                                             self._callTestMethod(testMethod)
E                                           File "usr/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
E                                             method()
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/absl/testing/parameterized.py", line 316, in bound_param_test
E                                             return test_method(self, *testcase_params)
E                                           File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 148, in test_model_compile
E                                             outputs = model.predict(inputs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 2033, in predict
E                                             tmp_batch_outputs = self.predict_function(iterator)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
E                                             result = self._call(*args, **kwds)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
E                                             self._initialize(args, kwds, add_initializers_to=initializers)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
E                                             self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
E                                             graph_function, _ = self._maybe_define_function(args, kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
E                                             graph_function = self._create_graph_function(args, kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
E                                             func_graph_module.func_graph_from_py_func(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
E                                             func_outputs = python_func(*func_args, **func_kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
E                                             out = weak_wrapped_fn().__wrapped__(*args, **kwds)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
E                                             return autograph.converted_call(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1845, in predict_function
E                                             return step_function(self, iterator)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1834, in step_function
E                                             outputs = model.distribute_strategy.run(run_step, args=(data,))
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1312, in run
E                                             return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2888, in call_for_each_replica
E                                             return self._call_for_each_replica(fn, args, kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3689, in _call_for_each_replica
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
E                                             result = self._call(*args, **kwds)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
E                                             self._initialize(args, kwds, add_initializers_to=initializers)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
E                                             self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
E                                             graph_function, _ = self._maybe_define_function(args, kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
E                                             graph_function = self._create_graph_function(args, kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
E                                             func_graph_module.func_graph_from_py_func(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
E                                             func_outputs = python_func(*func_args, **func_kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
E                                             out = weak_wrapped_fn().__wrapped__(*args, **kwds)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
E                                             return autograph.converted_call(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1823, in run_step
E                                             outputs = model.predict_step(data)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1791, in predict_step
E                                             return self(x, training=False)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
E                                             return super().__call__(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
E                                             outputs = call_fn(inputs, *args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
E                                             return fn(*args, **kwargs)
E                                           File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 129, in call
E                                             if not training:
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1341, in if_stmt
E                                             _py_if_stmt(cond, body, orelse)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1394, in _py_if_stmt
E                                             return body() if cond else orelse()
E                                           File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 130, in call
E                                             generated = greedy_search(
E                                           File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation.py", line 150, in greedy_search
E                                             prompt = tf.while_loop(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
E                                             return func(*args, **kwargs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2507, in while_loop_v2
E                                             return while_loop(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2705, in while_loop
E                                             return while_v2.while_loop(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 298, in while_loop
E                                             outputs = _build_while_op(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 499, in _build_while_op
E                                             return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util_v2.py", line 376, in run_as_function_for_tape_gradients
E                                             return make_op(inputs)
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 475, in _make_op
E                                             while_op, tensors = util.get_op_and_outputs(op_fn(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/gen_functional_ops.py", line 1007, in stateless_while
E                                             _, _, _op, _outputs = _op_def_library._apply_op_helper(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 797, in _apply_op_helper
E                                             op = g._create_op_internal(op_type_name, inputs, dtypes=None,
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 694, in _create_op_internal
E                                             return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3754, in _create_op_internal
E                                             ret = Operation(
E                                           File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2133, in __init__
E                                             self._traceback = tf_stack.extract_stack_for_node(self._c_op)
E                                           
E                                                [[{{node test_model/while}}]]
E                                                [[StatefulPartitionedCall]] [Op:__inference_predict_function_1960]

../../python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:54: InvalidArgumentError
------------------------------------------------------------- Captured stderr call -------------------------------------------------------------
2022-07-22 19:46:08.211510: I tensorflow/compiler/xla/service/service.cc:170] XLA service 0x6570110 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-07-22 19:46:08.211578: I tensorflow/compiler/xla/service/service.cc:178]   StreamExecutor device (0): Host, Default Version
2022-07-22 19:46:08.424551: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at while_op.cc:475 : INVALID_ARGUMENT: Input and output shapes of loop body do not match: (s32[], s32[], s32[2,2]) vs. (s32[], s32[], s32[2,3])

Stack trace for op definition: 
File "home/abheesht/python_envs/keras_nlp/bin/pytest", line 8, in <module>
  sys.exit(console_main())
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 187, in console_main
  code = main()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 164, in main
  ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 315, in pytest_cmdline_main
  return wrap_session(config, _main)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 268, in wrap_session
  session.exitstatus = doit(config, session) or 0
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 322, in _main
  config.hook.pytest_runtestloop(session=session)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 347, in pytest_runtestloop
  item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 111, in pytest_runtest_protocol
  runtestprotocol(item, nextitem=nextitem)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 130, in runtestprotocol
  reports.append(call_and_report(item, "call", log))
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 219, in call_and_report
  call = call_runtest_hook(item, when, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 258, in call_runtest_hook
  return CallInfo.from_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 338, in from_call
  result: Optional[TResult] = func()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 259, in <lambda>
  lambda: ihook(item=item, **kwds), when=when, reraise=reraise
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 166, in pytest_runtest_call
  item.runtest()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/unittest.py", line 327, in runtest
  self._testcase(result=self)  # type: ignore[arg-type]
File "usr/lib/python3.8/unittest/case.py", line 736, in __call__
  return self.run(*args, **kwds)
File "usr/lib/python3.8/unittest/case.py", line 676, in run
  self._callTestMethod(testMethod)
File "usr/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
  method()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/absl/testing/parameterized.py", line 316, in bound_param_test
  return test_method(self, *testcase_params)
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 148, in test_model_compile
  outputs = model.predict(inputs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 2033, in predict
  tmp_batch_outputs = self.predict_function(iterator)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
  result = self._call(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
  self._initialize(args, kwds, add_initializers_to=initializers)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
  self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
  graph_function, _ = self._maybe_define_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
  graph_function = self._create_graph_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
  func_graph_module.func_graph_from_py_func(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
  func_outputs = python_func(*func_args, **func_kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
  out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
  return autograph.converted_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1845, in predict_function
  return step_function(self, iterator)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1834, in step_function
  outputs = model.distribute_strategy.run(run_step, args=(data,))
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1312, in run
  return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2888, in call_for_each_replica
  return self._call_for_each_replica(fn, args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3689, in _call_for_each_replica
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
  result = self._call(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
  self._initialize(args, kwds, add_initializers_to=initializers)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
  self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
  graph_function, _ = self._maybe_define_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
  graph_function = self._create_graph_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
  func_graph_module.func_graph_from_py_func(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
  func_outputs = python_func(*func_args, **func_kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
  out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
  return autograph.converted_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1823, in run_step
  outputs = model.predict_step(data)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1791, in predict_step
  return self(x, training=False)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
  return super().__call__(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 129, in call
  if not training:
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1341, in if_stmt
  _py_if_stmt(cond, body, orelse)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1394, in _py_if_stmt
  return body() if cond else orelse()
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 130, in call
  generated = greedy_search(
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation.py", line 150, in greedy_search
  prompt = tf.while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
  return func(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2507, in while_loop_v2
  return while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2705, in while_loop
  return while_v2.while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 298, in while_loop
  outputs = _build_while_op(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 499, in _build_while_op
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util_v2.py", line 376, in run_as_function_for_tape_gradients
  return make_op(inputs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 475, in _make_op
  while_op, tensors = util.get_op_and_outputs(op_fn(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/gen_functional_ops.py", line 1007, in stateless_while
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 797, in _apply_op_helper
  op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 694, in _create_op_internal
  return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3754, in _create_op_internal
  ret = Operation(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2133, in __init__
  self._traceback = tf_stack.extract_stack_for_node(self._c_op)

2022-07-22 19:46:08.425257: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at xla_ops.cc:296 : INVALID_ARGUMENT: Input and output shapes of loop body do not match: (s32[], s32[], s32[2,2]) vs. (s32[], s32[], s32[2,3])

Stack trace for op definition: 
File "home/abheesht/python_envs/keras_nlp/bin/pytest", line 8, in <module>
  sys.exit(console_main())
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 187, in console_main
  code = main()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/config/__init__.py", line 164, in main
  ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 315, in pytest_cmdline_main
  return wrap_session(config, _main)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 268, in wrap_session
  session.exitstatus = doit(config, session) or 0
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 322, in _main
  config.hook.pytest_runtestloop(session=session)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/main.py", line 347, in pytest_runtestloop
  item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 111, in pytest_runtest_protocol
  runtestprotocol(item, nextitem=nextitem)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 130, in runtestprotocol
  reports.append(call_and_report(item, "call", log))
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 219, in call_and_report
  call = call_runtest_hook(item, when, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 258, in call_runtest_hook
  return CallInfo.from_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 338, in from_call
  result: Optional[TResult] = func()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 259, in <lambda>
  lambda: ihook(item=item, **kwds), when=when, reraise=reraise
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_hooks.py", line 265, in __call__
  return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_manager.py", line 80, in _hookexec
  return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/pluggy/_callers.py", line 39, in _multicall
  res = hook_impl.function(*args)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/runner.py", line 166, in pytest_runtest_call
  item.runtest()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/_pytest/unittest.py", line 327, in runtest
  self._testcase(result=self)  # type: ignore[arg-type]
File "usr/lib/python3.8/unittest/case.py", line 736, in __call__
  return self.run(*args, **kwds)
File "usr/lib/python3.8/unittest/case.py", line 676, in run
  self._callTestMethod(testMethod)
File "usr/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
  method()
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/absl/testing/parameterized.py", line 316, in bound_param_test
  return test_method(self, *testcase_params)
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 148, in test_model_compile
  outputs = model.predict(inputs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 2033, in predict
  tmp_batch_outputs = self.predict_function(iterator)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
  result = self._call(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
  self._initialize(args, kwds, add_initializers_to=initializers)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
  self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
  graph_function, _ = self._maybe_define_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
  graph_function = self._create_graph_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
  func_graph_module.func_graph_from_py_func(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
  func_outputs = python_func(*func_args, **func_kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
  out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
  return autograph.converted_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1845, in predict_function
  return step_function(self, iterator)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1834, in step_function
  outputs = model.distribute_strategy.run(run_step, args=(data,))
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1312, in run
  return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2888, in call_for_each_replica
  return self._call_for_each_replica(fn, args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3689, in _call_for_each_replica
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 141, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 915, in __call__
  result = self._call(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 963, in _call
  self._initialize(args, kwds, add_initializers_to=initializers)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
  self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2480, in _get_concrete_function_internal_garbage_collected
  graph_function, _ = self._maybe_define_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
  graph_function = self._create_graph_function(args, kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
  func_graph_module.func_graph_from_py_func(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
  func_outputs = python_func(*func_args, **func_kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
  out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 1116, in autograph_handler
  return autograph.converted_call(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1823, in run_step
  outputs = model.predict_step(data)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 1791, in predict_step
  return self(x, training=False)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
  return super().__call__(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 60, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 129, in call
  if not training:
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1341, in if_stmt
  _py_if_stmt(cond, body, orelse)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1394, in _py_if_stmt
  return body() if cond else orelse()
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation_test.py", line 130, in call
  generated = greedy_search(
File "home/abheesht/repos/keras-nlp/keras_nlp/utils/text_generation.py", line 150, in greedy_search
  prompt = tf.while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
  return func(*args, **kwargs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2507, in while_loop_v2
  return while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2705, in while_loop
  return while_v2.while_loop(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 298, in while_loop
  outputs = _build_while_op(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 499, in _build_while_op
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util_v2.py", line 376, in run_as_function_for_tape_gradients
  return make_op(inputs)
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 475, in _make_op
  while_op, tensors = util.get_op_and_outputs(op_fn(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/ops/gen_functional_ops.py", line 1007, in stateless_while
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 797, in _apply_op_helper
  op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 694, in _create_op_internal
  return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3754, in _create_op_internal
  ret = Operation(
File "home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 2133, in __init__
  self._traceback = tf_stack.extract_stack_for_node(self._c_op)

         [[{{node test_model/while}}]]
=========================================================== short test summary info ============================================================
FAILED keras_nlp/utils/text_generation_test.py::GreedySearchTextGenerationTest::test_model_compile_jit_compile_true - tensorflow.python.frame...
=================================================== 1 failed, 45 passed, 5 skipped in 31.70s ===================================================
(keras_nlp) abheesht@LAPTOP-M2NKFTLU:~/repos/keras-nlp$