diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 8470184d808..8c980870d0c 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -271,6 +271,8 @@ def get_per_channel_dtype(
                     if force_fp32
                     else XNNDatatype.xnn_datatype_fp16
                 )
+            elif node_dtype is not None and node_dtype == torch.bfloat16:
+                dtype = XNNDatatype.xnn_datatype_bf16
 
         return dtype
 
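(For reference, the branch added above routes bfloat16 nodes to the new bf16 serialization type, alongside the existing fp16/fp32 selection. A minimal sketch of that selection logic; pick_xnn_dtype is a hypothetical stand-in, not the actual node_visitor internals:)

    # Sketch only: condenses the dtype selection that get_per_channel_dtype
    # performs; not ExecuTorch API.
    import torch
    from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
        XNNDatatype,
    )

    def pick_xnn_dtype(node_dtype, force_fp32=False):
        if node_dtype == torch.float16 and not force_fp32:
            return XNNDatatype.xnn_datatype_fp16
        if node_dtype == torch.bfloat16:
            return XNNDatatype.xnn_datatype_bf16
        return XNNDatatype.xnn_datatype_fp32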
diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index df6067a7d68..a40be5bfa84 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -219,6 +219,7 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
     def _check_node_has_valid_dtype(self, node):
         valid_dtypes = {
             torch.float32,
+            torch.bfloat16,
             torch.float16,
             torch.int8,
             torch.qint8,
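(With torch.bfloat16 in valid_dtypes, bf16 graphs are no longer rejected up front by the partitioner's dtype check. A hedged end-to-end sketch; the entry points below are the commonly used ones but may differ across ExecuTorch versions:)

    # Sketch: export a small bf16 module and hand it to the XNNPACK partitioner.
    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class TinyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 16)

        def forward(self, x):
            return self.linear(x)

    model = TinyLinear().to(torch.bfloat16).eval()
    example = (torch.randn(4, 32, dtype=torch.bfloat16),)
    ep = torch.export.export(model, example)
    # bf16 nodes now pass _check_node_has_valid_dtype and can be delegated.
    edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
    et_program = edge.to_executorch()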
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 445744e9918..0aa92d12799 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -97,7 +97,10 @@ std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
 }
 
 /*
-Converts flatbuffer xnn data type to xnnpack data type
+Converts flatbuffer xnn data type to xnnpack data type.
+
+NOTE: The flatbuffer enum values are not the same as XNNPACK's datatype
+enum values, so each case must be mapped explicitly.
 */
 xnn_datatype getDataType(const DataType& data_type) {
   switch (data_type) {
@@ -121,6 +124,14 @@ xnn_datatype getDataType(const DataType& data_type) {
       return xnn_datatype::xnn_datatype_qdint8;
     case DataType::xnn_datatype_qbint4:
       return xnn_datatype::xnn_datatype_qbint4;
+    case DataType::xnn_datatype_qpint8:
+      return xnn_datatype::xnn_datatype_qpint8;
+    case DataType::xnn_datatype_int32:
+      return xnn_datatype::xnn_datatype_int32;
+    case DataType::xnn_datatype_pfp32:
+      return xnn_datatype::xnn_datatype_pfp32;
+    case DataType::xnn_datatype_bf16:
+      return xnn_datatype::xnn_datatype_bf16;
     default:
       return xnn_datatype::xnn_datatype_invalid;
   }
@@ -1888,42 +1899,6 @@ Error defineStaticSliceNode(
   return Error::Ok;
 }
 
-/*
-Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
-using the remapped ids to map the serialized ids,
-to the new ids generated when defining the tensor value
-*/
-Error defineScaledDotProductAttentionNode(
-    xnn_subgraph_t subgraph_ptr,
-    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
-    const NodePtr node,
-    const fb_xnnpack::XNNGraph* graph) noexcept {
-  MAYBE_UNUSED(graph);
-
-  auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
-
-  xnn_status status = xnn_define_scaled_dot_product_attention(
-      subgraph_ptr,
-      xnn_attention_logits_cap_type_none, // cap_type
-      nullptr, // cap_value - not used
-      remapped_ids.at(graph_node->query_id()),
-      remapped_ids.at(graph_node->key_id()),
-      remapped_ids.at(graph_node->value_id()),
-      remapped_ids.at(graph_node->scale_id()),
-      remapped_ids.at(graph_node->mask_id()),
-      remapped_ids.at(graph_node->output_id()),
-      graph_node->flags());
-
-  ET_CHECK_OR_RETURN_ERROR(
-      status == xnn_status_success,
-      Internal,
-      "Failed to create SDPA node %i with code: %s",
-      node->debug_handle(),
-      xnn_status_to_string(status));
-
-  return Error::Ok;
-}
-
 /*
 Defines batch matrix multiply node into the subgraph,
 using the remapped ids to map the serialized ids,
@@ -2023,7 +1998,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate4)
     _DEFINE(Concatenate5)
     _DEFINE(StaticSlice)
-    _DEFINE(ScaledDotProductAttention)
     _DEFINE(BatchMatrixMultiply)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case
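(The NOTE above is why getDataType() cannot simply cast: the serialized values come from the schema's own enum, not from xnnpack.h. A quick check of the schema-side values, using the Python serialization enum updated in this same patch:)

    # The flatbuffer stores the schema enum's numeric values (e.g. bf16 == 14);
    # XNNPACK's xnn_datatype constants are defined independently in xnnpack.h,
    # so the runtime translates case by case instead of casting.
    from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
        XNNDatatype,
    )

    print(int(XNNDatatype.xnn_datatype_bf16))    # 14
    print(int(XNNDatatype.xnn_datatype_qdint8))  # 9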
diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs
index f10ba3d1b81..99f9e4e5fbd 100644
--- a/backends/xnnpack/serialization/runtime_schema.fbs
+++ b/backends/xnnpack/serialization/runtime_schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
   xnn_datatype_qdint8 = 9,
   /// Quantized 4-bit signed integer with shared blockwise quantization parameters.
   xnn_datatype_qbint4 = 10,
+  /// Dynamically quantized 8-bit signed integers packed with their per-row
+  /// quantization parameters.
+  xnn_datatype_qpint8 = 11,
+  /// 32-bit signed integers.
+  xnn_datatype_int32 = 12,
+  /// IEEE754 single-precision packed floating-point.
+  xnn_datatype_pfp32 = 13,
+  /// BFloat16, i.e. the upper 16 bits of a float32.
+  xnn_datatype_bf16 = 14,
 }
 
 // type of quantization
diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs
index 565eb4c3bba..e3ed4061e94 100644
--- a/backends/xnnpack/serialization/schema.fbs
+++ b/backends/xnnpack/serialization/schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
   xnn_datatype_qdint8 = 9,
   /// Quantized 4-bit signed integer with shared blockwise quantization parameters.
   xnn_datatype_qbint4 = 10,
+  /// Dynamically quantized 8-bit signed integers packed with their per-row
+  /// quantization parameters.
+  xnn_datatype_qpint8 = 11,
+  /// 32-bit signed integers.
+  xnn_datatype_int32 = 12,
+  /// IEEE754 single-precision packed floating-point.
+  xnn_datatype_pfp32 = 13,
+  /// BFloat16, i.e. the upper 16 bits of a float32.
+  xnn_datatype_bf16 = 14,
 }
 
 // type of quantization
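(A quick illustration of the bf16 comment above, "the upper 16 bits of a float32": truncating the fp32 bit pattern to its top half gives the bfloat16 encoding. Truncation is shown for simplicity; real converters typically round to nearest even.)

    # bf16 keeps fp32's sign bit, 8 exponent bits, and the top 7 mantissa bits.
    import struct

    def fp32_to_bf16_bits(x: float) -> int:
        (bits,) = struct.unpack("<I", struct.pack("<f", x))
        return bits >> 16

    def bf16_bits_to_fp32(b: int) -> float:
        (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
        return x

    print(bf16_bits_to_fp32(fp32_to_bf16_bits(3.14159)))  # ~3.140625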
diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py
index 2a3ccaf2a0a..4e23e199dec 100644
--- a/backends/xnnpack/serialization/xnnpack_graph_schema.py
+++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -413,6 +413,10 @@ class XNNDatatype(IntEnum):
     xnn_datatype_qcint4 = 8
     xnn_datatype_qdint8 = 9
     xnn_datatype_qbint4 = 10
+    xnn_datatype_qpint8 = 11
+    xnn_datatype_int32 = 12
+    xnn_datatype_pfp32 = 13
+    xnn_datatype_bf16 = 14
 
 
 @dataclass
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 421e59c0b08..dcdd05633b3 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -63,7 +63,7 @@ def __init__(
         self.ic = input_channels
         self.oc = output_channels
 
-        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+        assert dtype in [torch.bfloat16, torch.float, torch.half], "Unsupported op dtype"
         self.op_dtype = dtype
         self.in_size = in_size
 
@@ -831,6 +831,9 @@ def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self):
     def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self):
         self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.float)
 
+    def test_linear_qd8_bf16_per_token_weight_per_channel_group_int4(self):
+        self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.bfloat16)
+
     @unittest.skipIf(
         not torchao_installed, "Per Channel Group Quantization Required TorchAO"
     )
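(To exercise just the new bf16 group-wise path, something like the following works; it assumes the enclosing test case class is TestLinear and that torchao is installed, which the group-wise quantization helper requires:)

    # Runs only the new bf16 variant of the qd8 per-token / per-channel-group
    # int4 linear test added in this diff.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        "executorch.backends.xnnpack.test.ops.test_linear.TestLinear."
        "test_linear_qd8_bf16_per_token_weight_per_channel_group_int4"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)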
diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK
index 4ea82e595b3..4b106fa6089 160000
--- a/backends/xnnpack/third-party/XNNPACK
+++ b/backends/xnnpack/third-party/XNNPACK
@@ -1 +1 @@
-Subproject commit 4ea82e595b36106653175dcb04b2aa532660d0d8
+Subproject commit 4b106fa60892b33b2bddba11dbdb64550e2dfb3a
diff --git a/backends/xnnpack/third-party/pthreadpool b/backends/xnnpack/third-party/pthreadpool
index 4fe0e1e1839..dcc9f285890 160000
--- a/backends/xnnpack/third-party/pthreadpool
+++ b/backends/xnnpack/third-party/pthreadpool
@@ -1 +1 @@
-Subproject commit 4fe0e1e183925bf8cfa6aae24237e724a96479b8
+Subproject commit dcc9f28589066af0dbd4555579281230abbf74dd