Basic support for the ace step 1.5 model. (#12237)
This commit is contained in:
@@ -103,6 +103,52 @@ class Qwen3_06BConfig:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06B_ACE15_Config:
|
||||
vocab_size: int = 151669
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 3072
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_2B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
@@ -729,6 +775,27 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06B_ACE15_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def logits(self, x):
|
||||
return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user